def test_v1_5_0_model_checkpoint_save_checkpoint():
    model_ckpt = ModelCheckpoint()
    trainer = Trainer()
    trainer.save_checkpoint = lambda *_, **__: None
    with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"):
        model_ckpt.save_checkpoint(trainer, object())

def test_deprecated_mc_save_checkpoint():
    mc = ModelCheckpoint()
    trainer = Trainer()
    with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call(
        match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
    ):
        mc.save_checkpoint(trainer)

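# Note: both tests above exercise the deprecated `ModelCheckpoint.save_checkpoint(trainer)`
# entry point. A minimal sketch of the non-deprecated way to write a checkpoint manually
# (illustrative only; `tmpdir` is a pytest fixture and the filename is arbitrary):
#
#     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
#     trainer.fit(BoringModel())
#     trainer.save_checkpoint(os.path.join(tmpdir, "manual.ckpt"))
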
@pytest.mark.parametrize("fast_dev_run", [1, 4])
def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
    """Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run."""

    class FastDevRunModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.training_step_call_count = 0
            self.training_epoch_end_call_count = 0
            self.validation_step_call_count = 0
            self.validation_epoch_end_call_count = 0
            self.test_step_call_count = 0

        def training_step(self, batch, batch_idx):
            self.log("some_metric", torch.tensor(7.0))
            self.logger.experiment.dummy_log("some_distribution", torch.randn(7) + batch_idx)
            self.training_step_call_count += 1
            return super().training_step(batch, batch_idx)

        def training_epoch_end(self, outputs):
            self.training_epoch_end_call_count += 1
            super().training_epoch_end(outputs)

        def validation_step(self, batch, batch_idx):
            self.validation_step_call_count += 1
            return super().validation_step(batch, batch_idx)

        def validation_epoch_end(self, outputs):
            self.validation_epoch_end_call_count += 1
            super().validation_epoch_end(outputs)

        def test_step(self, batch, batch_idx):
            self.test_step_call_count += 1
            return super().test_step(batch, batch_idx)

    checkpoint_callback = ModelCheckpoint()
    checkpoint_callback.save_checkpoint = Mock()
    early_stopping_callback = EarlyStopping(monitor="foo")
    early_stopping_callback._evaluate_stopping_criteria = Mock()
    trainer_config = dict(
        default_root_dir=tmpdir,
        fast_dev_run=fast_dev_run,
        val_check_interval=2,
        logger=True,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )

    def _make_fast_dev_run_assertions(trainer, model):
        # check the call count for train/val/test step/epoch
        assert model.training_step_call_count == fast_dev_run
        assert model.training_epoch_end_call_count == 1
        assert model.validation_step_call_count == (0 if model.validation_step is None else fast_dev_run)
        assert model.validation_epoch_end_call_count == (0 if model.validation_step is None else 1)
        assert model.test_step_call_count == fast_dev_run

        # check trainer arguments
        assert trainer.max_steps == fast_dev_run
        assert trainer.num_sanity_val_steps == 0
        assert trainer.max_epochs == 1
        assert trainer.val_check_interval == 1.0
        assert trainer.check_val_every_n_epoch == 1

        # there should be no logger with fast_dev_run
        assert isinstance(trainer.logger, DummyLogger)

        # checkpoint callback should not have been called with fast_dev_run
        assert trainer.checkpoint_callback == checkpoint_callback
        checkpoint_callback.save_checkpoint.assert_not_called()
        assert not os.path.exists(checkpoint_callback.dirpath)

        # early stopping should not have been called with fast_dev_run
        assert trainer.early_stopping_callback == early_stopping_callback
        early_stopping_callback._evaluate_stopping_criteria.assert_not_called()

    train_val_step_model = FastDevRunModel()
    trainer = Trainer(**trainer_config)
    trainer.fit(train_val_step_model)
    trainer.test(train_val_step_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    _make_fast_dev_run_assertions(trainer, train_val_step_model)

    # -----------------------
    # also called once with no val step
    # -----------------------
    train_step_only_model = FastDevRunModel()
    train_step_only_model.validation_step = None

    trainer = Trainer(**trainer_config)
    trainer.fit(train_step_only_model)
    trainer.test(train_step_only_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    _make_fast_dev_run_assertions(trainer, train_step_only_model)

def train_task(init, close, exp_cfg_path, env_cfg_path, task_nr, logger_pass=None):
    seed_everything(42)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    if local_rank != 0 or not init:
        print(init, local_rank)
        rm = exp_cfg_path.find('cfg/exp/') + len('cfg/exp/')
        exp_cfg_path = os.path.join(exp_cfg_path[:rm], 'tmp/', exp_cfg_path[rm:])

    exp = load_yaml(exp_cfg_path)
    env = load_yaml(env_cfg_path)

    if local_rank == 0 and init:
        # Set the correct model path in the experiment name
        if exp.get('timestamp', True):
            timestamp = datetime.datetime.now().replace(microsecond=0).isoformat()
            model_path = os.path.join(env['base'], exp['name'])
            p = model_path.split('/')
            model_path = os.path.join('/', *p[:-1], str(timestamp) + '_' + p[-1])
        else:
            model_path = os.path.join(env['base'], exp['name'])
            try:
                shutil.rmtree(model_path)
            except Exception:
                pass

        # Create the directory
        if not os.path.exists(model_path):
            try:
                os.makedirs(model_path)
            except Exception:
                print("Failed generating network run folder")
        else:
            print("Network run folder already exists")

        # Only copy config files for the main ddp-task
        exp_cfg_fn = os.path.split(exp_cfg_path)[-1]
        env_cfg_fn = os.path.split(env_cfg_path)[-1]
        print(f'Copy {env_cfg_path} to {model_path}/{exp_cfg_fn}')
        shutil.copy(exp_cfg_path, f'{model_path}/{exp_cfg_fn}')
        shutil.copy(env_cfg_path, f'{model_path}/{env_cfg_fn}')
        exp['name'] = model_path
    else:
        # The correct model path has already been written to the yaml file.
        model_path = os.path.join(exp['name'], f'rank_{local_rank}_{task_nr}')
        # Create the directory
        if not os.path.exists(model_path):
            try:
                os.makedirs(model_path)
            except Exception:
                pass

    # if local_rank == 0 and env['workstation'] == False:
    #     cm = open(os.path.join(model_path, f'info{local_rank}_{task_nr}.log'), 'w')
    # else:
    #     cm = nullcontext()
    # with cm as f:
    #     if local_rank == 0 and env['workstation'] == False:
    #         cm2 = redirect_stdout(f)
    #     else:
    #         cm2 = nullcontext()
    #     with cm2:
    #         # Setup logger for each ddp-task
    #         logging.getLogger("lightning").setLevel(logging.DEBUG)
    #         logger = logging.getLogger("lightning")
    #         fh = logging.FileHandler( , 'a')
    #         logger.addHandler(fh)

    # Copy dataset from scratch storage to the node's SSD
    if env['workstation'] == False:
        # use proxy hack for neptune.ai !!!
        NeptuneLogger._create_or_get_experiment = _create_or_get_experiment2

        # move data to ssd
        if exp['move_datasets'][0]['env_var'] != 'none':
            for dataset in exp['move_datasets']:
                scratchdir = os.getenv('TMPDIR')
                env_var = dataset['env_var']
                tar = os.path.join(env[env_var], f'{env_var}.tar')
                name = (tar.split('/')[-1]).split('.')[0]

                if not os.path.exists(os.path.join(scratchdir, dataset['env_var'])):
                    try:
                        cmd = f"tar -xvf {tar} -C $TMPDIR >/dev/null 2>&1"
                        st = time.time()
                        print(f'Start moving dataset-{env_var}: {cmd}')
                        os.system(cmd)
                        env[env_var] = str(os.path.join(scratchdir, name))
                        print(f'Finished moving dataset-{env_var} in {time.time() - st}s')
                    except Exception:
                        rank_zero_warn('ENV Var' + env_var)
                        env[env_var] = str(os.path.join(scratchdir, name))
                        rank_zero_warn('Copying data failed')
                else:
                    env[env_var] = str(os.path.join(scratchdir, name))
        else:
            env['mlhypersim'] = str(os.path.join(env['mlhypersim'], 'mlhypersim'))

    if (exp['trainer']).get('gpus', -1):
        nr = torch.cuda.device_count()
        exp['trainer']['gpus'] = nr
        print(f'Set GPU Count for Trainer to {nr}!')

    model = Network(exp=exp, env=env)

    lr_monitor = LearningRateMonitor(**exp['lr_monitor']['cfg'])
    if exp['cb_early_stopping']['active']:
        early_stop_callback = EarlyStopping(**exp['cb_early_stopping']['cfg'])
        cb_ls = [early_stop_callback, lr_monitor]
    else:
        cb_ls = [lr_monitor]

    tses = TaskSpecificEarlyStopping(
        nr_tasks=exp['task_generator']['total_tasks'],
        **exp['task_specific_early_stopping'])
    cb_ls.append(tses)

    if local_rank == 0:
        for i in range(exp['task_generator']['total_tasks']):
            if i == task_nr:
                m = '/'.join([a for a in model_path.split('/') if a.find('rank') == -1])
                dic = copy.deepcopy(exp['cb_checkpoint']['cfg'])
                # try:
                #     if len(exp['cb_checkpoint'].get('nameing', [])) > 0:
                #         # filepath += '-{task_name:10s}'
                #         for m in exp['cb_checkpoint']['nameing']:
                #             filepath += '-{' + m + ':.2f}'
                # except:
                #     pass
                # dic['monitor'] += str(i)
                checkpoint_callback = ModelCheckpoint(
                    dirpath=m,
                    filename='task' + str(i) + '-{epoch:02d}--{step:06d}',
                    **dic)
                cb_ls.append(checkpoint_callback)

    params = log_important_params(exp)

    if env['workstation']:
        t1 = 'workstation'
    else:
        t1 = 'leonhard'

    # if local_rank == 0:
    cwd = os.getcwd()
    files = [
        str(p).replace(cwd + '/', '') for p in Path(cwd).rglob('*.py')
        if str(p).find('vscode') == -1
    ]
    files.append(exp_cfg_path)
    files.append(env_cfg_path)

    if not exp.get('offline_mode', False):
        # if exp.get('experiment_id', -1) == -1:
        # create a new experiment_id and write it back
        if logger_pass is None:
            logger = NeptuneLogger(
                api_key=os.environ["NEPTUNE_API_TOKEN"],
                project_name="jonasfrey96/asl",
                experiment_name=exp['name'].split('/')[-2] + "_" + exp['name'].split('/')[-1],  # Optional,
                params=params,  # Optional,
                tags=[
                    t1, exp['name'].split('/')[-2], exp['name'].split('/')[-1]
                ] + exp["tag_list"],  # Optional,
                close_after_fit=False,
                offline_mode=exp.get('offline_mode', False),
                upload_source_files=files,
                upload_stdout=False,
                upload_stderr=False)
            exp['experiment_id'] = logger.experiment.id
            print('created experiment id' + str(exp['experiment_id']))
        else:
            logger = logger_pass
        # else:
        #     print('loaded experiment id' + str(exp['experiment_id']))
        #     # TODO
        #     logger = NeptuneLogger(
        #         api_key=os.environ["NEPTUNE_API_TOKEN"],
        #         project_name="jonasfrey96/asl",
        #         experiment_name=exp['name'].split('/')[-2] + "_" + exp['name'].split('/')[-1],  # Optional,
        #         params=params,  # Optional,
        #         tags=[t1, exp['name'].split('/')[-2], exp['name'].split('/')[-1]] + exp["tag_list"],  # Optional,
        #         close_after_fit=False,
        #         offline_mode=exp.get('offline_mode', False),
        #         upload_source_files=files,
        #         upload_stdout=False,
        #         upload_stderr=False
        #     )
        # logger = NeptuneLogger(
        #     api_key=os.environ["NEPTUNE_API_TOKEN"],
        #     project_name="jonasfrey96/asl",
        #     experiment_id=exp.get('experiment_id', -1),
        #     close_after_fit=False,
        # )
        print('Neptune Experiment ID: ' + str(logger.experiment.id) + " TASK NR " + str(task_nr))
    else:
        logger = TensorBoardLogger(
            save_dir=model_path,
            name='tensorboard',  # Optional,
            default_hp_metric=params,  # Optional,
        )
    # else:
    #     logger = TensorBoardLogger(
    #         save_dir=model_path + '/rank/' + str(local_rank),
    #         name=exp['name'].split('/')[-2] + "_" + exp['name'].split('/')[-1],  # Optional,
    #     )

    weight_restore = exp.get('weights_restore', False)
    checkpoint_load = exp['checkpoint_load']

    if local_rank == 0 and init:
        # Write back the exp file with the correct name set to the model_path!
        # Other ddp-tasks don't need to care about timestamps.
        # Also store the path to last.ckpt so that downstream tasks can restore the model state.
        exp['weights_restore_2'] = False
        exp['checkpoint_restore_2'] = True
        exp['checkpoint_load_2'] = os.path.join(model_path, 'last.ckpt')

        rm = exp_cfg_path.find('cfg/exp/') + len('cfg/exp/')
        exp_cfg_path = os.path.join(exp_cfg_path[:rm], 'tmp/', exp_cfg_path[rm:])
        Path(exp_cfg_path).parent.mkdir(parents=True, exist_ok=True)
        with open(exp_cfg_path, 'w+') as f:
            yaml.dump(exp, f, default_flow_style=False, sort_keys=False)

    if not init:
        # Restore the model state from the previous task.
        exp['checkpoint_restore'] = exp['checkpoint_restore_2']
        exp['checkpoint_load'] = exp['checkpoint_load_2']
        exp['weights_restore'] = exp['weights_restore_2']

    # Always use advanced profiler
    if exp['trainer'].get('profiler', False):
        exp['trainer']['profiler'] = AdvancedProfiler(
            output_filename=os.path.join(model_path, 'profile.out'))
    else:
        exp['trainer']['profiler'] = False

    # print(exp['trainer'])
    # print(os.environ.get('GLOBAL_RANK'))
    if exp.get('checkpoint_restore', False):
        p = os.path.join(env['base'], exp['checkpoint_load'])
        trainer = Trainer(**exp['trainer'],
                          default_root_dir=model_path,
                          callbacks=cb_ls,
                          resume_from_checkpoint=p,
                          logger=logger)
    else:
        trainer = Trainer(**exp['trainer'],
                          default_root_dir=model_path,
                          callbacks=cb_ls,
                          logger=logger)

    if exp['weights_restore']:
        # Not strict, since the latent replay buffer is not always available.
        p = os.path.join(env['base'], exp['checkpoint_load'])
        if os.path.isfile(p):
            res = model.load_state_dict(
                torch.load(p, map_location=lambda storage, loc: storage)['state_dict'],
                strict=False)
            print('Restoring weights: ' + str(res))
        else:
            raise Exception('Checkpoint not a file')

    main_visu = MainVisualizer(p_visu=os.path.join(model_path, 'main_visu'),
                               logger=logger,
                               epoch=0,
                               store=True,
                               num_classes=22)

    tc = TaskCreator(**exp['task_generator'], output_size=exp['model']['input_size'])
    print(tc)

    _task_start_training = time.time()
    _task_start_time = time.time()

    for idx, out in enumerate(tc):
        if idx == task_nr:
            break

    if True:  # for idx, out in enumerate(tc):
        task, eval_lists = out
        main_visu.epoch = idx
        # New Logger
        print(f'<<<<<<<<<<<< TASK IDX {idx} TASK NAME : ' + task.name + ' >>>>>>>>>>>>>')

        model._task_name = task.name
        model._task_count = idx
        dataloader_train, dataloader_buffer = get_dataloader_train(
            d_train=task.dataset_train_cfg, env=env, exp=exp)
        print(str(dataloader_train.dataset))
        print(str(dataloader_buffer.dataset))
        dataloader_list_test = eval_lists_into_dataloaders(eval_lists, env=env, exp=exp)
        print(f'<<<<<<<<<<<< All Datasets are loaded and set up >>>>>>>>>>>>>')

        # Training the model
        trainer.should_stop = False
print("GLOBAL STEP ", model.global_step) for d in dataloader_list_test: print(str(d.dataset)) if idx < exp['start_at_task']: # trainer.limit_val_batches = 1.0 trainer.limit_train_batches = 1 trainer.max_epochs = 1 trainer.check_val_every_n_epoch = 1 train_res = trainer.fit(model=model, train_dataloader=dataloader_train, val_dataloaders=dataloader_list_test) trainer.max_epochs = exp['trainer']['max_epochs'] trainer.check_val_every_n_epoch = exp['trainer'][ 'check_val_every_n_epoch'] trainer.limit_val_batches = exp['trainer']['limit_val_batches'] trainer.limit_train_batches = exp['trainer']['limit_train_batches'] else: print('Train', dataloader_train) print('Val', dataloader_list_test) train_res = trainer.fit(model=model, train_dataloader=dataloader_train, val_dataloaders=dataloader_list_test) res = trainer.logger_connector.callback_metrics res_store = {} for k in res.keys(): try: res_store[k] = float(res[k]) except: pass base_path = '/'.join( [a for a in model_path.split('/') if a.find('rank') == -1]) with open(f"{base_path}/res{task_nr}.pkl", "wb") as f: pickle.dump(res_store, f) print(f'<<<<<<<<<<<< TASK IDX {idx} TASK NAME : ' + task.name + ' Trained >>>>>>>>>>>>>') if exp.get('buffer', {}).get('fill_after_fit', False): print(f'<<<<<<<<<<<< Performance Test to Get Buffer >>>>>>>>>>>>>') trainer.test(model=model, test_dataloaders=dataloader_buffer) if local_rank == 0: checkpoint_callback.save_checkpoint(trainer, model) print(f'<<<<<<<<<<<< Performance Test DONE >>>>>>>>>>>>>') number_validation_dataloaders = len(dataloader_list_test) if model._rssb_active: # visualize rssb bins, valids = model._rssb.get() fill_status = (bins != 0).sum(axis=1) main_visu.plot_bar(fill_status, x_label='Bin', y_label='Filled', title='Fill Status per Bin', sort=False, reverse=False, tag='Buffer_Fill_Status') plot_from_pkl(main_visu, base_path, task_nr) try: if close: logger.experiment.stop() except: pass