def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune=False,
    use_global_unstructured=False,
    pruning_fn="l1_unstructured",
    use_lottery_ticket_hypothesis=False,
    accelerator=None,
    gpus=None,
    num_processes=1,
):
    model = TestModel()

    # Weights are random; none of them should be 0
    assert torch.all(model.layer.mlp_2.weight != 0)

    pruning_kwargs = {
        "pruning_fn": pruning_fn,
        "amount": 0.3,
        "use_global_unstructured": use_global_unstructured,
        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
        "verbose": 1,
    }
    if parameters_to_prune:
        pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
    else:
        if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
            pruning_kwargs["parameter_names"] = ["weight"]
        else:
            pruning_kwargs["parameter_names"] = ["weight", "bias"]
    if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
        pruning_kwargs["pruning_dim"] = 0
    if pruning_fn == "ln_structured":
        pruning_kwargs["pruning_norm"] = 1

    # Misconfiguration checks
    if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured") and use_global_unstructured:
        with pytest.raises(MisconfigurationException, match="is supported with `use_global_unstructured=True`"):
            ModelPruning(**pruning_kwargs)
        return
    if ModelPruning._is_pruning_method(pruning_fn) and not use_global_unstructured:
        with pytest.raises(MisconfigurationException, match="currently only supported with"):
            ModelPruning(**pruning_kwargs)
        return

    pruning = ModelPruning(**pruning_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        weights_summary=None,
        checkpoint_callback=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=10,
        accelerator=accelerator,
        gpus=gpus,
        num_processes=num_processes,
        callbacks=pruning,
    )
    trainer.fit(model)
    trainer.test(model)

    if not accelerator:
        # Check that some weights have been pruned
        assert torch.any(model.layer.mlp_2.weight == 0)
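# Illustrative driver for the helper above -- a sketch, not the original test module.
# The test name and the parameter combinations are assumptions chosen only to show how
# the helper could be exercised through pytest's parametrization.
@pytest.mark.parametrize("parameters_to_prune", [False, True])
@pytest.mark.parametrize("use_global_unstructured", [False, True])
def test_pruning_callback(tmpdir, parameters_to_prune, use_global_unstructured):
    train_with_pruning_callback(
        tmpdir,
        parameters_to_prune=parameters_to_prune,
        use_global_unstructured=use_global_unstructured,
    )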
def test_disabled_checkpointing(tmpdir):
    # no callback
    trainer = Trainer(max_epochs=3, enable_checkpointing=False)
    assert not trainer.checkpoint_callbacks

    trainer.fit(BoringModel())
    assert not trainer.checkpoint_callbacks
def test_gradient_accumulation_scheduling():
    """Test grad accumulation by the freq of optimizer updates."""

    # test incorrect configs
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        # only test the first 12 batches in epoch
        if batch_nb < 12:
            if epoch_nb == 0:
                # reset counter when starting epoch
                if batch_nb == 0:
                    self.prev_called_batch_nb = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 1

            elif 1 <= epoch_nb <= 2:
                # reset counter when starting epoch
                if batch_nb == 1:
                    self.prev_called_batch_nb = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 2

            else:
                if batch_nb == 3:
                    self.prev_called_batch_nb = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 3

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = get_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_nb_epochs=4)

    # for the test
    trainer.optimizer_step = optimizer_step

    model.prev_called_batch_nb = 0

    trainer.fit(model)
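# A minimal sketch (not Lightning internals) of the schedule semantics the test above
# relies on: with ``accumulate_grad_batches={1: 2, 3: 4}`` the accumulation factor for a
# given epoch is the value of the largest key not exceeding that epoch, and 1 before the
# first key. The helper name is an assumption used only for illustration.
def resolve_accumulation_factor(schedule, epoch):
    factor = 1
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor


assert resolve_accumulation_factor({1: 2, 3: 4}, 0) == 1  # epoch 0: no accumulation
assert resolve_accumulation_factor({1: 2, 3: 4}, 2) == 2  # epochs 1-2: accumulate 2 batches
assert resolve_accumulation_factor({1: 2, 3: 4}, 5) == 4  # epoch 3 onwards: accumulate 4 batches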
def test_metric_collections(tmpdir): """This test ensures the metric attribute is properly found even with complex nested metric structure.""" class TestModel(BoringModel): def __init__(self): super().__init__() self.metrics_list = ModuleList([DummyMetric() for _ in range(2)]) self.metrics_dict = ModuleDict({ "a": DummyMetric(), "b": DummyMetric() }) self.metrics_collection_dict = MetricCollection({ "a": DummyMetric(), "b": DummyMetric() }) self.metrics_collection_dict_nested = ModuleDict({ "a": ModuleList([ModuleDict({"b": DummyMetric()}), DummyMetric()]) }) def training_step(self, batch, batch_idx): loss = super().training_step(batch, batch_idx) self.metrics_list[0](batch_idx) self.metrics_list[1](batch_idx) self.metrics_dict["a"](batch_idx) self.metrics_dict["b"](batch_idx) self.metrics_collection_dict["a"](batch_idx) self.metrics_collection_dict["b"](batch_idx) self.metrics_collection_dict_nested["a"][0]["b"](batch_idx) self.metrics_collection_dict_nested["a"][1](batch_idx) self.log("a", self.metrics_list[0]) self.log("b", self.metrics_list[1]) self.log("c", self.metrics_dict["a"]) self.log("d", self.metrics_dict["b"]) self.log("e", self.metrics_collection_dict["a"]) self.log("f", self.metrics_collection_dict["b"]) self.log("g", self.metrics_collection_dict_nested["a"][0]["b"]) self.log("h", self.metrics_collection_dict_nested["a"][1]) return loss def on_train_epoch_end(self) -> None: results = self.trainer.fit_loop.epoch_loop._results assert results[ "training_step.a"].meta.metric_attribute == "metrics_list.0" assert results[ "training_step.b"].meta.metric_attribute == "metrics_list.1" assert results[ "training_step.c"].meta.metric_attribute == "metrics_dict.a" assert results[ "training_step.d"].meta.metric_attribute == "metrics_dict.b" assert results[ "training_step.e"].meta.metric_attribute == "metrics_collection_dict.a" assert results[ "training_step.f"].meta.metric_attribute == "metrics_collection_dict.b" assert results[ "training_step.g"].meta.metric_attribute == "metrics_collection_dict_nested.a.0.b" assert results[ "training_step.h"].meta.metric_attribute == "metrics_collection_dict_nested.a.1" model = TestModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=0) trainer.fit(model)
def main(cfg: DictConfig): cur_dir = hydra.utils.get_original_cwd() os.chdir(cur_dir) # Random Seed seed_everything(cfg.train.seed) # Model #################################################################### net = ENet(model_name=cfg.train.model_name) transform = ImageTransform(img_size=cfg.data.img_size) # Comet.ml experiment = Experiment(api_key=cfg.comet_ml.api_key, project_name=cfg.comet_ml.project_name) # Log Parameters experiment.log_parameters(dict(cfg.exp)) experiment.log_parameters(dict(cfg.data)) experiment.log_parameters(dict(cfg.train)) # Log Model Graph experiment.set_model_graph(str(net)) # Lightning Module ######################################################### model = LightningSystem(net, cfg, experiment) datamodule = DataModule(data_dir, cfg, transform, cv) checkpoint_callback = ModelCheckpoint(filepath='./checkpoint', save_top_k=1, verbose=True, monitor='avg_val_loss', mode='min', prefix=cfg.exp.exp_name + '_') trainer = Trainer(logger=False, max_epochs=cfg.train.epoch, checkpoint_callback=checkpoint_callback, gpus=1) # Train & Test ############################################################ # Train trainer.fit(model, datamodule=datamodule) experiment.log_metric('best_auc', model.best_auc) checkpoint_path = glob.glob(f'./checkpoint/{cfg.exp.exp_name}_*.ckpt')[0] experiment.log_asset(file_data=checkpoint_path) # Test for i in range(test_num): trainer.test(model) # Submit sub_list = glob.glob(f'submission_{cfg.exp.exp_name}*.csv') _ = summarize_submit(sub_list, experiment, filename=f'sub_{cfg.exp.exp_name}.csv') # oof oof_dataset = datamodule.oof_dataset oof_dataloader = DataLoader(oof_dataset, batch_size=cfg.train.batch_size, pin_memory=False, shuffle=False, drop_last=False) for i in range(10): trainer.test(model, test_dataloaders=oof_dataloader) # Submit sub_list = glob.glob('submission*.csv') _ = summarize_submit(sub_list, experiment, filename=f'oof_{cfg.exp.exp_name}.csv')
def test_cpu_restore_training():
    """Verify continue training session on CPU."""
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(max_nb_epochs=2,
                           val_check_interval=0.50,
                           val_percent_check=0.2,
                           train_percent_check=0.2,
                           experiment=exp,
                           checkpoint_callback=ModelCheckpoint(save_dir))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_epoch = trainer.current_epoch

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        run_prediction(trainer.val_dataloader, trainer.model)

    model.on_sanity_check_start = assert_good_acc

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()
def test_trainer_callback_system(tmpdir):
    """Test the callback system."""

    class CurrentTestModel(
            LightTrainDataloader,
            LightTestMixin,
            LightValidationMixin,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    def _check_args(trainer, pl_module):
        assert isinstance(trainer, Trainer)
        assert isinstance(pl_module, LightningModule)

    class TestCallback(Callback):

        def __init__(self):
            super().__init__()
            self.on_init_start_called = False
            self.on_init_end_called = False
            self.on_epoch_start_called = False
            self.on_epoch_end_called = False
            self.on_batch_start_called = False
            self.on_batch_end_called = False
            self.on_train_start_called = False
            self.on_train_end_called = False
            self.on_validation_start_called = False
            self.on_validation_end_called = False
            self.on_test_start_called = False
            self.on_test_end_called = False

        def on_init_start(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_start_called = True

        def on_init_end(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_end_called = True

        def on_epoch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_start_called = True

        def on_epoch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_end_called = True

        def on_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_start_called = True

        def on_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_end_called = True

        def on_train_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_start_called = True

        def on_train_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_end_called = True

        def on_validation_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_start_called = True

        def on_validation_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_end_called = True

        def on_test_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_start_called = True

        def on_test_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_end_called = True

    test_callback = TestCallback()

    trainer_options = {
        'callbacks': [test_callback],
        'max_epochs': 1,
        'val_percent_check': 0.1,
        'train_percent_check': 0.2,
        'show_progress_bar': False,
    }

    assert not test_callback.on_init_start_called
    assert not test_callback.on_init_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    # fit model
    trainer = Trainer(**trainer_options)

    assert trainer.callbacks[0] == test_callback
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    trainer.fit(model)

    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert test_callback.on_epoch_start_called
    assert test_callback.on_epoch_end_called
    assert test_callback.on_batch_start_called
    assert test_callback.on_batch_end_called
    assert test_callback.on_train_start_called
    assert test_callback.on_train_end_called
    assert test_callback.on_validation_start_called
    assert test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    trainer.test()

    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called
class LightningTrainer(BaseTrainer): def __init__(self, config: DictConfig): super().__init__(config) self.trainer = None self.trainer_config = self.config.trainer.params self.data_module = None def load(self): super().load() self._calculate_max_updates() self._load_loggers() self._load_trainer() def _load_trainer(self): lightning_params = self.trainer_config with omegaconf.open_dict(lightning_params): lightning_params.pop("max_steps") lightning_params.pop("max_epochs") lightning_params_dict = OmegaConf.to_container(lightning_params, resolve=True) self.trainer = Trainer(callbacks=self._callbacks, max_steps=self._max_updates, default_root_dir=get_mmf_env(key="log_dir"), **lightning_params_dict) def configure_device(self) -> None: pass def configure_seed(self) -> None: seed = self.config.training.seed seed_everything(seed) def _load_loggers(self) -> None: self.tb_writer = None if self.training_config.tensorboard: # TODO: @sash PL logger upgrade log_dir = setup_output_folder(folder_only=True) env_tb_logdir = get_mmf_env(key="tensorboard_logdir") if env_tb_logdir: log_dir = env_tb_logdir self.tb_writer = TensorboardLogger(log_dir) def load_datasets(self) -> None: logger.info("Loading datasets") data_module = MultiDataModule(self.config) self.data_module = data_module self.train_loader = data_module.train_dataloader() self.val_loader = data_module.val_dataloader() self.test_loader = data_module.test_dataloader() def load_model(self) -> None: logger.info("Loading models") attributes = self.config.model_config[self.config.model] if isinstance(attributes, str): attributes = self.config.model_config[attributes] with omegaconf.open_dict(attributes): attributes.model = self.config.model self.model = build_model(attributes) self.model.is_pl_enabled = True self.model.build_meters(self.run_type) def load_optimizer(self) -> None: logger.info("Loading optimizer: noop for lightning") def load_metrics(self) -> None: logger.info("Loading metrics") metrics = self.config.evaluation.get("metrics", []) # moved metrics into the model object self.model.metrics = Metrics(metrics) def configure_callbacks(self) -> None: self._callbacks = [LightningLoopCallback(self)] def train(self) -> None: logger.info("===== Model =====") logger.info(self.model) print_model_parameters(self.model) logger.info("Starting training...") if "train" not in self.run_type: self.inference() return self.trainer.fit(self.model, self.data_module) # TODO: Look for a better way to hook this self.data_module.teardown() def inference(self) -> None: logger.info("Starting inference...") # TODO: @sash coming soon pass def _calculate_max_updates(self) -> None: self._max_updates = self.trainer_config.max_steps self._max_epochs = self.trainer_config.max_epochs if self._max_updates is None and self._max_epochs is None: raise ValueError( "Neither max_updates nor max_epochs is specified.") self._max_updates, max_epochs = get_max_updates( self._max_updates, self._max_epochs, self.train_loader, self.trainer_config.accumulate_grad_batches, ) self._max_epochs = math.ceil(max_epochs) return self._max_updates
def test_deepspeed_fp32_works(tmpdir):
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, gpus=1, strategy="deepspeed_stage_3", fast_dev_run=True)
    trainer.fit(model)
def test_ckpt_metric_names_results(tmpdir): class ResultLog(BoringModel): def training_step(self, batch, batch_idx): y_hat = self(batch) # calculate loss loss_val = self.loss(batch, y_hat) log_val = loss_val # alternate between tensors and scalars for "log" and "progress_bar" if batch_idx % 2 == 0: log_val = log_val.item() result = pl.core.step_result.TrainResult(loss_val) result.log('some_val', log_val * log_val, prog_bar=True, logger=False) result.log('train_some_val', log_val * log_val) return result def validation_step(self, batch, batch_idx): y_hat = self(batch) loss_val = self.loss(batch, y_hat) # acc labels_hat = torch.argmax(y_hat, dim=1) val_acc = torch.sum(batch == labels_hat).item() / (len(batch) * 1.0) val_acc = torch.tensor(val_acc).type_as(batch) result = pl.core.step_result.EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val) result.log_dict({ 'val_loss': loss_val, 'val_acc': val_acc, }) return result model = ResultLog() model.training_step_end = None model.training_epoch_end = None model.validation_step_end = None model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, gradient_clip_val=1.0, overfit_batches=0.20, progress_bar_refresh_rate=0, limit_train_batches=0.01, limit_val_batches=0.01, callbacks=[ ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename="{val_loss:.2f}") ], ) trainer.fit(model) # make sure the checkpoint we saved has the metric in the name ckpts = os.listdir(tmpdir) ckpts = [x for x in ckpts if "val_loss" in x] assert len(ckpts) == 1 val = re.sub("[^0-9.]", "", ckpts[0]) assert len(val) > 3
def test_cpu_slurm_save_load(tmpdir): """Verify model save/load/checkpoint on CPU.""" hparams = tutils.get_default_hparams() model = EvalModelTemplate(hparams) # logger file to get meta logger = tutils.get_default_logger(tmpdir) version = logger.version # fit model trainer = Trainer( max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir) ) result = trainer.fit(model) real_global_step = trainer.global_step # traning complete assert result == 1, 'cpu model failed to complete' # predict with trained model before saving # make a prediction dataloaders = model.test_dataloader() if not isinstance(dataloaders, list): dataloaders = [dataloaders] for dataloader in dataloaders: for batch in dataloader: break x, y = batch x = x.view(x.size(0), -1) model.eval() pred_before_saving = model(x) # test HPC saving # simulate snapshot on slurm saved_filepath = trainer.hpc_save(tmpdir, logger) assert os.path.exists(saved_filepath) # new logger file to get meta logger = tutils.get_default_logger(tmpdir, version=version) trainer = Trainer( max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir), ) model = EvalModelTemplate(hparams) # set the epoch start hook so we can predict before the model does the full training def assert_pred_same(): assert trainer.global_step == real_global_step and trainer.global_step > 0 # predict with loaded model to make sure answers are the same trainer.model.eval() new_pred = trainer.model(x) assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 model.on_epoch_start = assert_pred_same # by calling fit again, we trigger training, loading weights from the cluster # and our hook to predict using current model before any more weight updates trainer.fit(model)
def main(): if pre_model == False: # Data Module and Model dm = DataModule(bs) dm.prepare_data() model = FCN(num_classes, learning_rate) # Running our Model trainer = Trainer(max_epochs=epochs, fast_dev_run=False, gpus=1, profiler=False, progress_bar_refresh_rate=1, logger=tboard_logger) trainer.fit(model, dm) if pre_model == True: model = models.segmentation.fcn_resnet101(pretrained=True).eval() transform_VOC = T.Compose([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) full_dataset = torchvision.datasets.VOCSegmentation( "/projects/brfi3983/", image_set='train', download=True, transform=transform_VOC) test_dataset = torchvision.datasets.VOCSegmentation( "/projects/brfi3983/", image_set='val', download=True, transform=transform_VOC) # # Loading our custom image and transforming it to our pretrained network (uncomment which image to use) # image = Image.open('bird.jpg') # image = Image.open('person.jpg') image = Image.open('dog.jpg') trf = T.Compose([ T.Resize(800), T.CenterCrop(720), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Convert single image to single batch inp = trf(image).unsqueeze(0) out = model(inp)['out'] print(f'Output shape: {out.shape}') # Seeing which classes are most dominant across its depth om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy() # Plotting Class distribution classes = np.unique(om) plt.hist(om, bins=num_classes) plt.title('Occurrences of Classes') plt.xlabel('Class') plt.ylabel('Count') # Plotting original image across segmented image fig, (ax1, ax2) = plt.subplots(1, 2) plt.suptitle('Image Segmentation') ax1.set_title('Original Image') ax1.imshow(image) ax2.set_title('Segmented Image') ax2.imshow(om, cmap='YlOrBr') plt.show()
def test_write_predictions(tmpdir, option: int, do_train: bool, gpus: int): class CustomBoringModel(BoringModel): def test_step(self, batch, batch_idx, optimizer_idx=None): output = self(batch) test_loss = self.loss(batch, output) self.log('test_loss', test_loss) batch_size = batch.size(0) lst_of_str = [ random.choice(['dog', 'cat']) for i in range(batch_size) ] lst_of_int = [random.randint(500, 1000) for i in range(batch_size)] lst_of_lst = [[x] for x in lst_of_int] lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)] prediction_file = getattr(self, 'prediction_file', 'predictions.pt') lazy_ids = torch.arange(batch_idx * batch_size, batch_idx * batch_size + batch_size) # Base if option == 0: self.write_prediction('idxs', lazy_ids, prediction_file) self.write_prediction('preds', output, prediction_file) # Check mismatching tensor len elif option == 1: self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file) self.write_prediction('preds', output, prediction_file) # write multi-dimension elif option == 2: self.write_prediction('idxs', lazy_ids, prediction_file) self.write_prediction('preds', output, prediction_file) self.write_prediction('x', batch, prediction_file) # write str list elif option == 3: self.write_prediction('idxs', lazy_ids, prediction_file) self.write_prediction('vals', lst_of_str, prediction_file) # write int list elif option == 4: self.write_prediction('idxs', lazy_ids, prediction_file) self.write_prediction('vals', lst_of_int, prediction_file) # write nested list elif option == 5: self.write_prediction('idxs', lazy_ids, prediction_file) self.write_prediction('vals', lst_of_lst, prediction_file) # write dict list elif option == 6: self.write_prediction('idxs', lazy_ids, prediction_file) self.write_prediction('vals', lst_of_dict, prediction_file) elif option == 7: self.write_prediction_dict({ 'idxs': lazy_ids, 'preds': output }, prediction_file) prediction_file = Path(tmpdir) / 'predictions.pt' dm = BoringDataModule() model = CustomBoringModel() model.test_epoch_end = None model.prediction_file = prediction_file.as_posix() trainer = Trainer( default_root_dir=tmpdir, max_epochs=3, weights_summary=None, deterministic=True, gpus=gpus, ) # Prediction file shouldn't exist yet because we haven't done anything assert not prediction_file.exists() if do_train: trainer.fit(model, dm) assert trainer.state.finished, f"Training failed with {trainer.state}" trainer.test(datamodule=dm) else: trainer.test(model, datamodule=dm) # check prediction file now exists and is of expected length assert prediction_file.exists() predictions = torch.load(prediction_file) assert len(predictions) == len(dm.random_test)
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument(
        "--encoder",
        default=None,
        type=str,
        help="encoder to initialize. Can accept a SimCLR model checkpoint or just an encoder name from encoders_dali",
    )
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--num_workers", default=1, type=int, help="number of workers to use to fetch data")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help="hidden dimensions in classification layer added onto model for finetuning",
    )
    parser.add_argument("--epochs", default=400, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. "
        "Leave blank to have no validation-based early stopping.",
    )
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument(
        "--withhold_split",
        default=0,
        type=float,
        help="decimal from 0-1 representing how much of the training data to withhold from either training or "
        "validation. Used for experimenting with how many labels are needed.",
    )
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--log_name", type=str, help="name of model to log on wandb and locally")
    parser.add_argument(
        "--online_eval",
        default=False,
        type=bool,
        help="Do finetuning on model if labels are provided as a sanity check",
    )
    args = parser.parse_args()

    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    num_workers = args.num_workers
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withhold = args.withhold_split
    gpus = args.gpus
    encoder = args.encoder
    log_name = 'SIMCLR_SSL_' + args.log_name + '.ckpt'
    online_eval = args.online_eval

    wandb_logger = WandbLogger(name=log_name, project='SpaceForce')

    checkpointed = '.ckpt' in encoder
    if checkpointed:
        print('Resuming SSL Training from Model Checkpoint')
        try:
            model = SIMCLR.load_from_checkpoint(checkpoint_path=encoder)
            embedding_size = model.embedding_size
        except Exception as e:
            print(e)
            print('invalid checkpoint to initialize SIMCLR. This checkpoint needs to include the encoder and '
                  'projection and is of the SIMCLR class from this library. Will try to initialize just the encoder')
            checkpointed = False

    if not checkpointed:
        encoder, embedding_size = load_encoder(encoder)
        model = SIMCLR(encoder=encoder,
                       embedding_size=embedding_size,
                       gpus=gpus,
                       epochs=epochs,
                       DATA_PATH=DATA_PATH,
                       withhold=withhold,
                       batch_size=batch_size,
                       val_split=val_split,
                       hidden_dims=hidden_dims,
                       train_transform=SimCLRTrainDataTransform,
                       val_transform=SimCLRTrainDataTransform,
                       num_workers=num_workers,
                       lr=lr)

    online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                          hidden_dim=None,
                                          z_dim=embedding_size,
                                          num_classes=model.num_classes,
                                          dataset='None')

    cbs = []
    backend = 'dp'

    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        cbs.append(cb)

    if online_eval:
        cbs.append(online_evaluator)
        backend = 'ddp'

    trainer = Trainer(gpus=gpus,
                      max_epochs=epochs,
                      progress_bar_refresh_rate=5,
                      callbacks=cbs,
                      distributed_backend=f'{backend}' if args.gpus > 1 else None,
                      logger=wandb_logger,
                      enable_pl_optimizer=True)

    print('USING BACKEND______________________________ ', backend)
    trainer.fit(model)

    Path("./models/SSL").mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(f"./models/SSL/{log_name}")
def test_cpu_slurm_save_load(): """ Verify model save/load/checkpoint on CPU :return: """ hparams = get_hparams() model = LightningTestModel(hparams) save_dir = init_save_dir() # exp file to get meta exp = get_exp(False) exp.argparse(hparams) exp.save() cluster_a = SlurmCluster() trainer_options = dict(max_nb_epochs=1, cluster=cluster_a, experiment=exp, checkpoint_callback=ModelCheckpoint(save_dir)) # fit model trainer = Trainer(**trainer_options) result = trainer.fit(model) real_global_step = trainer.global_step # traning complete assert result == 1, 'amp + ddp model failed to complete' # predict with trained model before saving # make a prediction for batch in model.test_dataloader: break x, y = batch x = x.view(x.size(0), -1) model.eval() pred_before_saving = model(x) # test registering a save function trainer.enable_auto_hpc_walltime_manager() # test HPC saving # simulate snapshot on slurm saved_filepath = trainer.hpc_save(save_dir, exp) assert os.path.exists(saved_filepath) # wipe-out trainer and model # retrain with not much data... this simulates picking training back up after slurm # we want to see if the weights come back correctly continue_tng_hparams = get_hparams(continue_training=True, hpc_exp_number=cluster_a.hpc_exp_number) trainer_options = dict( max_nb_epochs=1, cluster=SlurmCluster(continue_tng_hparams), experiment=exp, checkpoint_callback=ModelCheckpoint(save_dir), ) trainer = Trainer(**trainer_options) model = LightningTestModel(hparams) # set the epoch start hook so we can predict before the model does the full training def assert_pred_same(): assert trainer.global_step == real_global_step and trainer.global_step > 0 # predict with loaded model to make sure answers are the same trainer.model.eval() new_pred = trainer.model(x) assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 model.on_epoch_start = assert_pred_same # by calling fit again, we trigger training, loading weights from the cluster # and our hook to predict using current model before any more weight updates trainer.fit(model) clear_save_dir()
def test_dp_resume(tmpdir): """Make sure DP continues training correctly.""" hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) trainer_options = dict( max_epochs=1, gpus=2, distributed_backend='dp', default_root_dir=tmpdir, ) # get logger logger = tutils.get_default_logger(tmpdir) # exp file to get weights # logger file to get weights checkpoint = tutils.init_checkpoint_callback(logger) # add these to the trainer options trainer_options['logger'] = logger trainer_options['checkpoint_callback'] = checkpoint # fit model trainer = Trainer(**trainer_options) trainer.is_slurm_managing_tasks = True result = trainer.fit(model) # track epoch before saving. Increment since we finished the current epoch, don't want to rerun real_global_epoch = trainer.current_epoch + 1 # correct result and ok accuracy assert result == 1, 'amp + dp model failed to complete' # --------------------------- # HPC LOAD/SAVE # --------------------------- # save trainer.checkpoint_connector.hpc_save(tmpdir, logger) # init new trainer new_logger = tutils.get_default_logger(tmpdir, version=logger.version) trainer_options['logger'] = new_logger trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir) trainer_options['limit_train_batches'] = 0.5 trainer_options['limit_val_batches'] = 0.2 trainer_options['max_epochs'] = 1 new_trainer = Trainer(**trainer_options) # set the epoch start hook so we can predict before the model does the full training def assert_good_acc(): assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0 # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() dataloader = trainer.train_dataloader tpipes.run_prediction(dataloader, dp_model, dp=True) # new model model = EvalModelTemplate(**hparams) model.on_train_start = assert_good_acc # fit new model which should load hpc weights new_trainer.fit(model) # test freeze on gpu model.freeze() model.unfreeze()
def test_amp_gpu_ddp_slurm_managed(): """ Make sure DDP + AMP work :return: """ if not torch.cuda.is_available(): warnings.warn('test_amp_gpu_ddp cannot run.' ' Rerun on a GPU node to run this test') return if not torch.cuda.device_count() > 1: warnings.warn('test_amp_gpu_ddp cannot run.' ' Rerun on a node with 2+ GPUs to run this test') return # simulate setting slurm flags os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0]) os.environ['SLURM_LOCALID'] = str(0) hparams = get_hparams() model = LightningTestModel(hparams) trainer_options = dict(progress_bar=True, max_nb_epochs=1, gpus=[0], distributed_backend='ddp', use_amp=True) save_dir = init_save_dir() # exp file to get meta exp = get_exp(False) exp.argparse(hparams) exp.save() # exp file to get weights checkpoint = ModelCheckpoint(save_dir) # add these to the trainer options trainer_options['checkpoint_callback'] = checkpoint trainer_options['experiment'] = exp # fit model trainer = Trainer(**trainer_options) trainer.is_slurm_managing_tasks = True result = trainer.fit(model) # correct result and ok accuracy assert result == 1, 'amp + ddp model failed to complete' # test root model address assert trainer.resolve_root_node_address('abc') == 'abc' assert trainer.resolve_root_node_address('abc[23]') == 'abc23' assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23' assert trainer.resolve_root_node_address( 'abc[23-24, 45-40, 40]') == 'abc23' # test model loading with a map_location map_location = 'cuda:1' pretrained_model = load_model(exp, save_dir, True, map_location) # test model preds run_prediction(model.test_dataloader, pretrained_model) if trainer.use_ddp: # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers( ) # test HPC loading / saving trainer.hpc_save(save_dir, exp) trainer.hpc_load(save_dir, on_gpu=True) # test freeze on gpu model.freeze() model.unfreeze() clear_save_dir()
def test_lr_scheduler_step_hook(tmpdir): """Test that custom lr scheduler works and `lr_scheduler_step` is called at appropriate time.""" class CustomEpochScheduler: def __init__(self, optimizer): self.optimizer = optimizer def step(self, epoch): ... def state_dict(self): ... def load_state_dict(self, state_dict): ... class CustomBoringModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx=0): return super().training_step(batch, batch_idx) def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # step-level if optimizer_idx == 0: super().lr_scheduler_step(scheduler, optimizer_idx, metric) # epoch-level elif optimizer_idx == 1: scheduler.step(epoch=self.current_epoch) def configure_optimizers(self): opt1 = torch.optim.SGD(self.layer.parameters(), lr=1e-2) lr_scheduler1 = { "scheduler": torch.optim.lr_scheduler.StepLR(opt1, step_size=1), "interval": "step" } opt2 = torch.optim.SGD(self.layer.parameters(), lr=1e-2) lr_scheduler2 = CustomEpochScheduler(opt2) return { "optimizer": opt1, "lr_scheduler": lr_scheduler1 }, { "optimizer": opt2, "lr_scheduler": lr_scheduler2, } model = CustomBoringModel() model.training_epoch_end = None max_epochs = 3 limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, enable_checkpointing=False, logger=False, max_epochs=max_epochs, limit_train_batches=limit_train_batches, limit_val_batches=0, ) with patch.object(CustomEpochScheduler, "step") as mock_method_epoch, patch.object( torch.optim.lr_scheduler.StepLR, "step") as mock_method_step: trainer.fit(model) assert mock_method_epoch.mock_calls == [ call(epoch=e) for e in range(max_epochs) ] # first step is called by PyTorch _LRScheduler assert mock_method_step.call_count == max_epochs * limit_train_batches + 1
def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad_make_optimizer_step( tmpdir): """ Test lightning optimize works with optimizer_zero_grad overrides and make_optimizer_step in automatic_optimization """ try: with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \ patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: class TestModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss = self.loss(batch, output) return {"loss": loss} def training_epoch_end(self, outputs): outputs = sum(outputs, []) torch.stack([x["loss"] for x in outputs]).mean() def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): if optimizer_idx == 0: if batch_idx % 2 == 0: optimizer.zero_grad() if optimizer_idx == 1: if batch_idx % 5 == 0: optimizer.zero_grad() def optimizer_step( self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs, ): assert optimizer_closure.__name__ == "train_step_and_backward_closure" if optimizer_idx == 0: optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 3 == 0) return optimizer.step(closure=optimizer_closure) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) return [optimizer_1, optimizer_2], [lr_scheduler] model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=20, limit_val_batches=1, max_epochs=1, weights_summary=None, ) trainer.fit(model) assert adam_zero_grad.call_count == 4 assert sgd_zero_grad.call_count == 10 except MisconfigurationException as e: assert "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed" in str( e)
def test_trainer_callback_system(tmpdir):
    """Test the callback system."""

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    def _check_args(trainer, pl_module):
        assert isinstance(trainer, Trainer)
        assert isinstance(pl_module, LightningModule)

    class TestCallback(Callback):

        def __init__(self):
            super().__init__()
            self.setup_called = False
            self.teardown_called = False
            self.on_init_start_called = False
            self.on_init_end_called = False
            self.on_fit_start_called = False
            self.on_fit_end_called = False
            self.on_sanity_check_start_called = False
            self.on_sanity_check_end_called = False
            self.on_epoch_start_called = False
            self.on_epoch_end_called = False
            self.on_batch_start_called = False
            self.on_batch_end_called = False
            self.on_train_batch_start_called = False
            self.on_train_batch_end_called = False
            self.on_validation_batch_start_called = False
            self.on_validation_batch_end_called = False
            self.on_test_batch_start_called = False
            self.on_test_batch_end_called = False
            self.on_train_start_called = False
            self.on_train_end_called = False
            self.on_pretrain_routine_start_called = False
            self.on_pretrain_routine_end_called = False
            self.on_validation_start_called = False
            self.on_validation_end_called = False
            self.on_test_start_called = False
            self.on_test_end_called = False
            self.on_after_backward_called = False
            self.on_before_zero_grad_called = False

        def setup(self, trainer, pl_module, stage: str):
            assert isinstance(trainer, Trainer)
            self.setup_called = True

        def teardown(self, trainer, pl_module, step: str):
            assert isinstance(trainer, Trainer)
            self.teardown_called = True

        def on_init_start(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_start_called = True

        def on_init_end(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_end_called = True

        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer, Trainer)
            self.on_fit_start_called = True

        def on_fit_end(self, trainer, pl_module):
            assert isinstance(trainer, Trainer)
            self.on_fit_end_called = True

        def on_sanity_check_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_start_called = True

        def on_sanity_check_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_end_called = True

        def on_epoch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_start_called = True

        def on_epoch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_end_called = True

        def on_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_start_called = True

        def on_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_end_called = True

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_train_batch_start_called = True

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_train_batch_end_called = True

        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_validation_batch_start_called = True

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_validation_batch_end_called = True

        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_test_batch_start_called = True

        def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_test_batch_end_called = True

        def on_train_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_start_called = True

        def on_train_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_end_called = True

        def on_pretrain_routine_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_pretrain_routine_start_called = True

        def on_pretrain_routine_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_pretrain_routine_end_called = True

        def on_validation_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_start_called = True

        def on_validation_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_end_called = True

        def on_test_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_start_called = True

        def on_test_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_end_called = True

        def on_after_backward(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_after_backward_called = True

        def on_before_zero_grad(self, trainer, pl_module, optimizer):
            _check_args(trainer, pl_module)
            self.on_before_zero_grad_called = True

    test_callback = TestCallback()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[test_callback],
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
        progress_bar_refresh_rate=0,
    )

    assert not test_callback.setup_called
    assert not test_callback.teardown_called
    assert not test_callback.on_init_start_called
    assert not test_callback.on_init_end_called
    assert not test_callback.on_fit_start_called
    assert not test_callback.on_fit_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_batch_start_called
    assert not test_callback.on_train_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_pretrain_routine_start_called
    assert not test_callback.on_pretrain_routine_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called
    assert not test_callback.on_after_backward_called
    assert not test_callback.on_before_zero_grad_called

    # fit model
    trainer = Trainer(**trainer_options)

    assert trainer.callbacks[0] == test_callback
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert not test_callback.setup_called
    assert not test_callback.teardown_called
    assert not test_callback.on_fit_start_called
    assert not test_callback.on_fit_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_batch_start_called
    assert not test_callback.on_train_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_pretrain_routine_start_called
    assert not test_callback.on_pretrain_routine_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called
    assert not test_callback.on_after_backward_called
    assert not test_callback.on_before_zero_grad_called

    trainer.fit(model)

    assert test_callback.setup_called
    assert test_callback.teardown_called
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert test_callback.on_fit_start_called
    assert test_callback.on_fit_end_called
    assert test_callback.on_sanity_check_start_called
    assert test_callback.on_sanity_check_end_called
    assert test_callback.on_epoch_start_called
    assert test_callback.on_epoch_end_called
    assert test_callback.on_batch_start_called
    assert test_callback.on_batch_end_called
    assert test_callback.on_train_batch_start_called
    assert test_callback.on_train_batch_end_called
    assert test_callback.on_validation_batch_start_called
    assert test_callback.on_validation_batch_end_called
    assert test_callback.on_train_start_called
    assert test_callback.on_train_end_called
    assert test_callback.on_pretrain_routine_start_called
    assert test_callback.on_pretrain_routine_end_called
    assert test_callback.on_validation_start_called
    assert test_callback.on_validation_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called
    assert test_callback.on_after_backward_called
    assert test_callback.on_before_zero_grad_called

    # reset setup teardown callback
    test_callback.teardown_called = False
    test_callback.setup_called = False

    test_callback = TestCallback()
    trainer_options.update(callbacks=[test_callback])
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    assert test_callback.setup_called
    assert test_callback.teardown_called
    assert test_callback.on_test_batch_start_called
    assert test_callback.on_test_batch_end_called
    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_after_backward_called
    assert not test_callback.on_before_zero_grad_called
def result_collection_reload(accelerator="auto", devices=1, **kwargs): """This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault Tolerant Training is correct.""" class CustomException(Exception): pass class ExtendedBoringModel(BoringModel): def __init__(self): super().__init__() self.breaking_batch_idx = 3 self.has_validated_sum = False self.dummy_metric = DummyMeanMetric() @property def results(self): return self.trainer.fit_loop._results def training_step(self, batch, batch_idx): # In the training step, we will accumulate metrics using batch_idx from 0 to 4 # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size` # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`. # However, below we will simulate a failure on `batch_idx=3`. if self.trainer.fit_loop.restarting: self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) self.dummy_metric(batch_idx) self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) value = self.results["training_step.tracking_metric"].value value_2 = self.results["training_step.tracking"].value # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks. # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]` shift = 0 if devices == 2: shift = 3 if self.trainer.is_global_zero else -3 expected = sum(range(batch_idx + 1)) + shift assert expected == value == value_2 else: if batch_idx == self.breaking_batch_idx: # simulate failure mid epoch raise CustomException self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) self.dummy_metric(batch_idx) self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) value = self.results["training_step.tracking"].value assert value == sum(range(batch_idx + 1)) value = self.results["training_step.tracking_2"] assert value == sum(range(batch_idx + 1)) return super().training_step(batch, batch_idx) def on_train_epoch_end(self) -> None: if self.trainer.fit_loop.restarting: total = sum(range(5)) * devices metrics = self.results.metrics(on_step=False) assert self.results["training_step.tracking"].value == total assert metrics["callback"][ "tracking"] == self.dummy_metric.compute() == 2 assert self.results["training_step.tracking_2"].value == total assert metrics["callback"][ "tracking_2"] == self.dummy_metric.compute() == 2 self.has_validated_sum = True model = ExtendedBoringModel() trainer_kwargs = { "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, "accelerator": accelerator, "devices": devices, } trainer_kwargs.update(kwargs) trainer = Trainer(**trainer_kwargs) with suppress(CustomException): trainer.fit(model) assert not model.has_validated_sum tmpdir = (trainer.strategy.broadcast(trainer_kwargs["default_root_dir"], 0) if devices >= 2 else trainer_kwargs["default_root_dir"]) ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") trainer = Trainer(**trainer_kwargs) trainer.fit(model, ckpt_path=ckpt_path) assert model.has_validated_sum
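# A minimal sketch of how the helper above might be invoked as a test, assuming the
# fault-tolerant code path is enabled through the ``PL_FAULT_TOLERANT_TRAINING``
# environment variable; the single-process CPU configuration and the test name are
# assumptions, not the original test matrix.
import os
from unittest import mock


def test_result_collection_reload(tmpdir):
    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
        result_collection_reload(default_root_dir=tmpdir, accelerator="cpu", devices=1)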
"max") profiler = pl.profiler.AdvancedProfiler(report_file) callbacks = [SaveModelCallback(), checkpoint_callback, early_stopping] tb_logger = pl_loggers.TensorBoardLogger(save_dir=logdir, name=exp_id, version=args.seed, log_graph=True) epochs = args.epochs trainer = Trainer( gpus=1 if device == "cuda" else 0, profiler=profiler, callbacks=callbacks, min_epochs=epochs, max_epochs=epochs + 5, progress_bar_refresh_rate=20, weights_summary="top", benchmark=True, logger=tb_logger, ) trainer.fit(model, train_loader, valid_loader) final_checkpoint_file = os.path.join(exp_log_dir, "final_epoch.pth") torch.save(model.state_dict(), final_checkpoint_file) model.eval() model = model.to(device) df = validate(model, valid_loader, "valid", exp_log_dir) df = validate(model, test_loader, "test", exp_log_dir)
def train(param): if not isinstance(param, dict): args = vars(param) else: args = param framework = get_class_by_name('conditioned_separation', args['model']) if args['spec_type'] != 'magnitude': args['input_channels'] = 4 if args['resume_from_checkpoint'] is None: if args['seed'] is not None: seed_everything(args['seed']) model = framework(**args) if args['last_activation'] != 'identity' and args[ 'spec_est_mode'] != 'masking': warn( 'Please check if you really want to use a mapping-based spectrogram estimation method ' 'with a final activation function. ') ########################################################## # -- checkpoint ckpt_path = Path(args['ckpt_root_path']) mkdir_if_not_exists(ckpt_path) ckpt_path = ckpt_path.joinpath(args['model']) mkdir_if_not_exists(ckpt_path) run_id = args['run_id'] ckpt_path = ckpt_path.joinpath(run_id) mkdir_if_not_exists(ckpt_path) save_top_k = args['save_top_k'] checkpoint_callback = ModelCheckpoint( filepath=ckpt_path, save_top_k=save_top_k, verbose=False, monitor='val_loss', save_last=False, save_weights_only=args['save_weights_only']) args['checkpoint_callback'] = checkpoint_callback # -- early stop patience = args['patience'] early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=0.0, patience=patience, verbose=False) args['early_stop_callback'] = early_stop_callback if args['resume_from_checkpoint'] is not None: run_id = run_id + "_resume_" + args['resume_from_checkpoint'] args['resume_from_checkpoint'] = Path(args['ckpt_root_path']).joinpath( args['model']).joinpath(args['run_id']).joinpath( args['resume_from_checkpoint']) args['resume_from_checkpoint'] = str(args['resume_from_checkpoint']) # -- logger setting log = args['log'] if log == 'False': args['logger'] = False elif log == 'wandb': args['logger'] = WandbLogger(project='lasaft', tags=args['model'], offline=False, id=run_id) args['logger'].log_hyperparams(model.hparams) args['logger'].watch(model, log='all') elif log == 'tensorboard': raise NotImplementedError else: args['logger'] = True # default default_save_path = 'etc/lightning_logs' mkdir_if_not_exists(default_save_path) valid_kwargs = inspect.signature(Trainer.__init__).parameters trainer_kwargs = dict( (name, args[name]) for name in valid_kwargs if name in args) # DATASET ########################################################## data_provider = DataProvider(**args) ########################################################## # Trainer Definition # Trainer trainer = Trainer(**trainer_kwargs) n_fft, hop_length, num_frame = args['n_fft'], args['hop_length'], args[ 'num_frame'] train_data_loader = data_provider.get_train_dataloader( n_fft, hop_length, num_frame) valid_data_loader = data_provider.get_valid_dataloader( n_fft, hop_length, num_frame) for key in sorted(args.keys()): print('{}:{}'.format(key, args[key])) if args['auto_lr_find']: lr_finder = trainer.lr_find(model, train_data_loader, valid_data_loader, early_stop_threshold=None) print(lr_finder.results) # torch.save(lr_finder.results, 'lr_result.cache') new_lr = lr_finder.suggestion() print('new_lr_suggestion:', new_lr) return 0 print(model) trainer.fit(model, train_data_loader, valid_data_loader) return None
def test_gradient_accumulation_scheduling(tmpdir): """ Test grad accumulation by the freq of optimizer updates """ tutils.reset_seed() # test incorrect configs with pytest.raises(IndexError): assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6}) assert Trainer(accumulate_grad_batches={-2: 3}) with pytest.raises(TypeError): assert Trainer(accumulate_grad_batches={}) assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]]) assert Trainer(accumulate_grad_batches={1: 2, 3.: 4}) assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5}) # test optimizer call freq matches scheduler def _optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): # only test the first 12 batches in epoch if batch_idx < 12: if epoch == 0: # reset counter when starting epoch if batch_idx == 0: self.prev_called_batch_idx = 0 # use this opportunity to test once assert self.trainer.accumulate_grad_batches == 1 assert batch_idx == self.prev_called_batch_idx self.prev_called_batch_idx += 1 elif 1 <= epoch <= 2: # reset counter when starting epoch if batch_idx == 1: self.prev_called_batch_idx = 1 # use this opportunity to test once assert self.trainer.accumulate_grad_batches == 2 assert batch_idx == self.prev_called_batch_idx self.prev_called_batch_idx += 2 else: if batch_idx == 3: self.prev_called_batch_idx = 3 # use this opportunity to test once assert self.trainer.accumulate_grad_batches == 4 assert batch_idx == self.prev_called_batch_idx self.prev_called_batch_idx += 3 optimizer.step() # clear gradients optimizer.zero_grad() hparams = tutils.get_default_hparams() model = LightningTestModel(hparams) schedule = {1: 2, 3: 4} trainer = Trainer(accumulate_grad_batches=schedule, train_percent_check=0.1, val_percent_check=0.1, max_epochs=2, default_root_dir=tmpdir) # for the test trainer.optimizer_step = _optimizer_step model.prev_called_batch_idx = 0 trainer.fit(model)
def test_training_loop_hook_call_order(tmpdir): """Tests that hooks / methods called in the training loop are in the correct order as detailed in the docs: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks""" class HookedModel(BoringModel): def __init__(self): super().__init__() self.called = [] def on_epoch_start(self): self.called.append("on_epoch_start") super().on_epoch_start() def on_train_epoch_start(self): self.called.append("on_train_epoch_start") super().on_train_epoch_start() def on_train_batch_start(self, batch, batch_idx, dataloader_idx): self.called.append("on_train_batch_start") super().on_train_batch_start(batch, batch_idx, dataloader_idx) def training_step(self, batch, batch_idx): self.called.append("training_step") return super().training_step(batch, batch_idx) def on_before_zero_grad(self, optimizer): self.called.append("on_before_zero_grad") super().on_before_zero_grad(optimizer) def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): self.called.append("optimizer_zero_grad") super().optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx) def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs): self.called.append("backward") super().backward(loss, optimizer, optimizer_idx, *args, **kwargs) def on_after_backward(self): self.called.append("on_after_backward") super().on_after_backward() def optimizer_step( self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs, ): super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs) self.called.append("optimizer_step" ) # append after as closure calls other methods def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.called.append("on_train_batch_end") super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) def training_epoch_end(self, outputs): self.called.append("training_epoch_end") super().training_epoch_end(outputs) def on_train_epoch_end(self, outputs): self.called.append("on_train_epoch_end") super().on_train_epoch_end(outputs) def on_epoch_end(self): self.called.append("on_epoch_end") super().on_epoch_end() model = HookedModel() # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=1, limit_train_batches=1, limit_test_batches=1, progress_bar_refresh_rate=0, weights_summary=None, ) assert model.called == [] trainer.fit(model) expected = [ 'on_epoch_start', # validation 'on_epoch_end', 'on_epoch_start', # training 'on_train_epoch_start', 'on_train_batch_start', 'training_step', 'on_before_zero_grad', 'optimizer_zero_grad', 'backward', 'on_after_backward', 'optimizer_step', 'on_train_batch_end', 'training_epoch_end', 'on_train_epoch_end', 'on_epoch_end', 'on_epoch_start', # validation 'on_epoch_end' ] assert model.called == expected
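# The same bookkeeping can live in a Callback instead of the LightningModule;
# a hedged sketch using a few hooks whose signatures are stable across 1.x
# releases (the expected ordering is still the one documented on the
# lightning_module hooks page linked in the test above):
from pytorch_lightning import Callback


class HookRecorder(Callback):
    def __init__(self):
        self.called = []

    def on_fit_start(self, trainer, pl_module):
        self.called.append("on_fit_start")

    def on_train_epoch_start(self, trainer, pl_module):
        self.called.append("on_train_epoch_start")

    def on_train_end(self, trainer, pl_module):
        self.called.append("on_train_end")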
def test_resume_from_checkpoint_epoch_restored(tmpdir): """Verify resuming from checkpoint runs the right number of epochs""" import types tutils.reset_seed() hparams = tutils.get_default_hparams() def _new_model(): # Create a model that tracks epochs and batches seen model = LightningTestModel(hparams) model.num_epochs_seen = 0 model.num_batches_seen = 0 def increment_epoch(self): self.num_epochs_seen += 1 def increment_batch(self, _): self.num_batches_seen += 1 # Bind the increment_epoch function on_epoch_end so that the # model keeps track of the number of epochs it has seen. model.on_epoch_end = types.MethodType(increment_epoch, model) model.on_batch_start = types.MethodType(increment_batch, model) return model model = _new_model() trainer_options = dict( progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.65, val_percent_check=1, checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), logger=False, default_root_dir=tmpdir, early_stop_callback=False, val_check_interval=1., ) # fit model trainer = Trainer(**trainer_options) trainer.fit(model) training_batches = trainer.num_training_batches assert model.num_epochs_seen == 2 assert model.num_batches_seen == training_batches * 2 # Other checkpoints can be uncommented if/when resuming mid-epoch is supported checkpoints = sorted( glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) for check in checkpoints: next_model = _new_model() state = torch.load(check) # Resume training trainer_options['max_epochs'] = 2 new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check) new_trainer.fit(next_model) assert state[ 'global_step'] + next_model.num_batches_seen == training_batches * trainer_options[ 'max_epochs']
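# The test above globs and sorts *.ckpt files by name; when checkpoint
# filenames do not sort chronologically, picking the newest file by
# modification time is an alternative. Illustrative helper, not Lightning API:
import glob
import os


def latest_checkpoint(ckpt_dir):
    ckpts = glob.glob(os.path.join(ckpt_dir, '*.ckpt'))
    return max(ckpts, key=os.path.getmtime) if ckpts else None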
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def train_dataloader(self, *args, **kwargs):
        ds = CIFAR10(root='./data', train=True, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True, num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        ds = CIFAR10(root='./data', train=False, download=True, transform=self.transform)
        # validation data should not be shuffled
        return torch.utils.data.DataLoader(ds, batch_size=32, shuffle=False, num_workers=0)


if __name__ == "__main__":
    model = LiftModel(models.resnet50(pretrained=True))
    trainer = Trainer(max_epochs=10, gpus=2, profiler="pytorch", accelerator="ddp")
    trainer.fit(model)
    trainer.validate(model)
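# The run above needs two GPUs (accelerator="ddp"); for a quick local smoke
# test the same model can be fitted on CPU with the lighter "simple" profiler.
# Sketch only - it reuses LiftModel/models from the snippet above:
if __name__ == "__main__":
    model = LiftModel(models.resnet50(pretrained=True))
    trainer = Trainer(max_epochs=1, limit_train_batches=10, limit_val_batches=5, profiler="simple")
    trainer.fit(model)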
def train(config): # ====================================================== # EXPERIMENT SETUP # ====================================================== from pytorch_lightning import seed_everything # Seed seed_everything(config.seed) # DATASET SETUP print("======================================================") print("SETTING UP DATASET") print("======================================================") from ml4floods.models.dataset_setup import get_dataset dataset = get_dataset(config.data_params) # MODEL SETUP print("======================================================") print("SETTING UP MODEL") print("======================================================") from ml4floods.models.model_setup import get_model config.model_params.test = False config.model_params.train = True model = get_model(config.model_params) # LOGGING SETUP print("======================================================") print("SETTING UP LOGGERS") print("======================================================") import wandb from pytorch_lightning.loggers import WandbLogger wandb_logger = WandbLogger( name=config.experiment_name, project=config.wandb_project, entity=config.wandb_entity, # save_dir=f"{config.model_params.model_folder}/{config.experiment_name}" ) # CHECKPOINTING SETUP print("======================================================") print("SETTING UP CHECKPOINTING") print("======================================================") from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}" checkpoint_path = f"{experiment_path}/checkpoint" checkpoint_callback = ModelCheckpoint( dirpath=checkpoint_path, save_top_k=True, verbose=True, monitor=config.model_params.hyperparameters.metric_monitor, mode='min', prefix='') early_stop_callback = EarlyStopping( monitor=config.model_params.hyperparameters.metric_monitor, patience=4, strict=False, verbose=False, mode='min') callbacks = [checkpoint_callback, early_stop_callback] # TRAINING SETUP print("======================================================") print("START TRAINING") print("======================================================") from pytorch_lightning import Trainer trainer = Trainer( fast_dev_run=False, logger=wandb_logger, callbacks=callbacks, default_root_dir= f"{config.model_params.model_folder}/{config.experiment_name}", accumulate_grad_batches=1, gradient_clip_val=0.0, auto_lr_find=False, benchmark=False, distributed_backend=None, gpus=config.gpus if config.gpus != '' else None, max_epochs=config.model_params.hyperparameters.max_epochs, check_val_every_n_epoch=config.model_params.hyperparameters.val_every, log_gpu_memory=None, resume_from_checkpoint=checkpoint_path if config.resume_from_checkpoint else None) trainer.fit(model, dataset) # ====================================================== # SAVING SETUP # ====================================================== print("======================================================") print("FINISHED TRAINING, SAVING MODEL") print("======================================================") from pytorch_lightning.utilities.cloud_io import atomic_save atomic_save(model.state_dict(), f"{experiment_path}/model.pt") torch.save(model.state_dict(), os.path.join(wandb_logger.save_dir, 'model.pt')) wandb.save(os.path.join(wandb_logger.save_dir, 'model.pt')) wandb.finish() # Save cofig file in experiment_path config_file_path = f"{experiment_path}/config.json" save_json(config, config_file_path) return 1
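# Note that ModelCheckpoint's save_top_k expects an int (save_top_k=True above
# is silently treated as 1). A hedged, minimal version of the callback setup;
# the paths and metric names here are illustrative:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/',   # illustrative directory
    monitor='val_loss',       # must match a metric logged during validation
    mode='min',
    save_top_k=1,
    verbose=True,
)
early_stop_callback = EarlyStopping(monitor='val_loss', mode='min', patience=4)
callbacks = [checkpoint_callback, early_stop_callback]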
def test_dp_resume(): """ Make sure DP continues training correctly :return: """ if not can_run_gpu_test(): return hparams = get_hparams() model = LightningTestModel(hparams) trainer_options = dict( show_progress_bar=True, max_nb_epochs=2, gpus=2, distributed_backend='dp', ) save_dir = init_save_dir() # get logger logger = get_test_tube_logger(debug=False) logger.log_hyperparams(hparams) # exp file to get weights checkpoint = ModelCheckpoint(save_dir) # add these to the trainer options trainer_options['logger'] = logger trainer_options['checkpoint_callback'] = checkpoint # fit model trainer = Trainer(**trainer_options) trainer.is_slurm_managing_tasks = True result = trainer.fit(model) # track epoch before saving real_global_epoch = trainer.current_epoch # correct result and ok accuracy assert result == 1, 'amp + dp model failed to complete' # --------------------------- # HPC LOAD/SAVE # --------------------------- # save trainer.hpc_save(save_dir, logger) # init new trainer new_logger = get_test_tube_logger(version=logger.version) trainer_options['logger'] = new_logger trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir) trainer_options['train_percent_check'] = 0.2 trainer_options['val_percent_check'] = 0.2 trainer_options['max_nb_epochs'] = 1 new_trainer = Trainer(**trainer_options) # set the epoch start hook so we can predict before the model does the full training def assert_good_acc(): assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0 # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() _ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader] # new model model = LightningTestModel(hparams) model.on_sanity_check_start = assert_good_acc # fit new model which should load hpc weights new_trainer.fit(model) # test freeze on gpu model.freeze() model.unfreeze() clear_save_dir()
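# hpc_save/hpc_load above are legacy internals; with current-ish Lightning the
# usual way to verify a restored model is load_from_checkpoint plus an
# inference-mode comparison. Hedged sketch (model_cls and batch are placeholders):
import torch


def predictions_match(model_cls, ckpt_path, reference_model, batch):
    restored = model_cls.load_from_checkpoint(ckpt_path)
    restored.eval()
    reference_model.eval()
    with torch.no_grad():
        return torch.allclose(restored(batch), reference_model(batch))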
def test_fit_can_fail_during_validation(train_datasets, val_datasets, val_check_interval, tmpdir): size, n_batches = 2, 4 stop_batch = 1 n_val_dataloaders = len(val_datasets) stop_dataloader = n_val_dataloaders - 1 class TestModel(LightningModule): def __init__(self, should_fail): super().__init__() self.layer = torch.nn.Linear(size, 2) self.should_fail = should_fail def step(self, batch): return sum(self.layer(b).sum() for b in batch) def training_step(self, batch, batch_idx): return self.step(batch) def validation_step(self, batch, batch_idx, dataloader_idx=0): if self.should_fail and dataloader_idx == stop_dataloader and batch_idx == stop_batch: raise CustomException return self.step(batch) def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) def train_dataloader(self): return [DataLoader(cls(size, n_batches)) for cls in train_datasets] def val_dataloader(self): return [DataLoader(cls(size, n_batches)) for cls in val_datasets] model = TestModel(False) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval, num_sanity_val_steps=0, enable_progress_bar=False, ) trainer.fit(model) ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") assert not os.path.exists(ckpt_path), "Shouldn't have failed" state_dict = trainer.fit_loop.state_dict() expected_global_step = trainer.global_step assert state_dict["epoch_loop.batch_progress"] == { "total": { "ready": n_batches, "started": n_batches, "processed": n_batches, "completed": n_batches }, "current": { "ready": n_batches, "started": n_batches, "processed": n_batches, "completed": n_batches }, "is_last_batch": True, } val_per_epoch = int(1 // val_check_interval) assert state_dict["epoch_loop.val_loop.dataloader_progress"] == { "total": { "ready": n_val_dataloaders * val_per_epoch, "completed": n_val_dataloaders * val_per_epoch }, "current": { "ready": n_val_dataloaders, "completed": n_val_dataloaders }, } assert state_dict["epoch_loop.val_loop.epoch_loop.batch_progress"] == { "total": { "ready": n_val_dataloaders * val_per_epoch * n_batches, "started": n_val_dataloaders * val_per_epoch * n_batches, "processed": n_val_dataloaders * val_per_epoch * n_batches, "completed": n_val_dataloaders * val_per_epoch * n_batches, }, "current": { "ready": n_batches, "completed": n_batches, "started": n_batches, "processed": n_batches }, "is_last_batch": True, } model = TestModel(True) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval, num_sanity_val_steps=0, enable_progress_bar=False, ) with pytest.raises(CustomException): # will stop during validation trainer.fit(model) assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"] per_val_train_batches = int(n_batches * val_check_interval) assert checkpoint["epoch_loop.batch_progress"] == { "total": { "ready": per_val_train_batches, "started": per_val_train_batches, "processed": per_val_train_batches, "completed": per_val_train_batches, }, "current": { "ready": per_val_train_batches, "started": per_val_train_batches, "processed": per_val_train_batches, "completed": per_val_train_batches, }, "is_last_batch": val_check_interval == 1, } val_batch_progress = "epoch_loop.val_loop.epoch_loop.batch_progress" # "nb_": non-breaking nb_total_val_batch = stop_dataloader * n_batches assert checkpoint[val_batch_progress] == { "total": { "ready": nb_total_val_batch + stop_batch + 1, "started": nb_total_val_batch + stop_batch + 1, "processed": nb_total_val_batch + stop_batch, 
"completed": nb_total_val_batch + stop_batch, }, "current": { "ready": stop_batch + 1, "started": stop_batch + 1, "processed": stop_batch, "completed": stop_batch, }, "is_last_batch": False, } model = TestModel(False) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval, enable_progress_bar=False, ) trainer.fit(model, ckpt_path=ckpt_path) assert trainer.global_step == expected_global_step state_dict_after_restart = trainer.fit_loop.state_dict() # should get the same values as in the run that did not fail # totals are increased by 1 (the failed batch which never completed) expected = state_dict.copy() assert state_dict_after_restart["epoch_loop.batch_progress"] == expected[ "epoch_loop.batch_progress"] val_dl_progress = "epoch_loop.val_loop.dataloader_progress" expected[val_dl_progress]["total"]["ready"] += 1 assert state_dict_after_restart[val_dl_progress] == expected[ val_dl_progress] expected[val_batch_progress]["total"]["ready"] += 1 expected[val_batch_progress]["total"]["started"] += 1 assert state_dict_after_restart[val_batch_progress] == expected[ val_batch_progress]