def main(parser):
    """Parse CLI arguments, train a model, and optionally run a test pass.

    Resolves the model/dataset classes named on the command line from the
    ``prevseg.models`` / ``prevseg.datasets`` namespaces, seeds everything,
    sets up a Neptune logger plus checkpoint/early-stop callbacks, trains
    with PyTorch Lightning, and finally visualizes a forward pass over the
    default validation path.

    :param parser: an ``argparse.ArgumentParser`` to extend; the caller may
        have pre-registered arguments on it before passing it in.
    """
    parser.add_argument('-m', '--model', type=str, default='PredNet')
    parser.add_argument('-d', '--dataset', type=str,
                        default='SchapiroResnetEmbeddingDataset')
    parser.add_argument('--load_model', action='store_true')
    parser.add_argument('--ipy', action='store_true')
    parser.add_argument('--no_graphs', action='store_true')
    parser.add_argument('--no_test', action='store_true')
    parser.add_argument('--user', type=str, default='aprashedahmed')
    parser.add_argument('-p', '--project', type=str, default='sandbox')
    parser.add_argument('-t', '--tags', nargs='+')
    parser.add_argument('--no_checkpoints', action='store_true')
    parser.add_argument('--offline_mode', action='store_true')
    parser.add_argument('--save_weights_online', action='store_true')
    parser.add_argument('--test_checkpoints', action='store_true')
    parser.add_argument('--test_epochs', type=int, default=2)
    parser.add_argument('--test_n_paths', type=int, default=2)
    parser.add_argument('--test_online', action='store_true')
    parser.add_argument('--test_project', type=str, default='')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--n_workers', type=int, default=1)
    parser.add_argument('-e', '--epochs', type=int, default=50)
    parser.add_argument('--gpus', type=float, default=1)
    parser.add_argument('--device', type=str, default='cuda')
    # Seed is a string so 'None'/'random' can be passed; converted below.
    parser.add_argument('-s', '--seed', type=str, default='random')
    parser.add_argument('-b', '--batch_size', type=int, default=256 + 128)
    parser.add_argument('--n_val', type=int, default=1)
    parser.add_argument('--mapping', type=str, default='random')
    parser.add_argument('--dir_checkpoints', type=str,
                        default=str(index.DIR_CHECKPOINTS))
    parser.add_argument('--checkpoint_period', type=float, default=1.0)
    parser.add_argument('--val_check_interval', type=float, default=1.0)
    parser.add_argument('--save_top_k', type=float, default=1)
    parser.add_argument('--early_stop_mode', type=str, default='min')
    parser.add_argument('--early_stop_patience', type=int, default=10)
    parser.add_argument('--early_stop_min_delta', type=float, default=0.001)
    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--exp_prefix', type=str, default='')
    parser.add_argument('--exp_suffix', type=str, default='')

    # Get Model and Dataset specific args
    temp_args, _ = parser.parse_known_args()

    # Make sure this is correct
    if hasattr(datasets, temp_args.dataset):
        Dataset = getattr(datasets, temp_args.dataset)
        parser = Dataset.add_dataset_specific_args(parser)
    else:
        raise Exception(
            f'Invalid dataset "{temp_args.dataset}" passed. Check it is '
            f'importable: "from prevseg.datasets import {temp_args.dataset}"')

    # Get temp args now with dataset args added
    temp_args, _ = parser.parse_known_args()

    # Check this is correct as well
    if hasattr(models, temp_args.model):
        Model = getattr(models, temp_args.model)
        parser = Model.add_model_specific_args(parser)
    else:
        raise Exception(
            f'Invalid model "{temp_args.model}" passed. Check it is importable:'
            f' "from prevseg.models import {temp_args.model}"')

    # Get the parser and turn into an omegaconf
    hparams = parser.parse_args()

    # If we are test-running, do a few things differently (scale down dataset,
    # send to sandbox project, etc.)
    # NOTE(review): `--test_run` is not registered in this function — it is
    # presumably added by the caller or by the model/dataset-specific args;
    # confirm, otherwise this attribute access raises AttributeError.
    if hparams.test_run:
        hparams.epochs = hparams.test_epochs
        hparams.n_paths = hparams.test_n_paths
        hparams.name = '_'.join(filter(None, ['test_run', hparams.exp_prefix]))
        hparams.project = hparams.test_project or 'sandbox'
        hparams.verbose = True
        # NOTE(review): only `--ipy` is registered above; `ipdb` looks like
        # a separate debugger flag consumed elsewhere — verify.
        hparams.ipdb = True
        hparams.no_checkpoints = not hparams.test_checkpoints
        hparams.offline_mode = not hparams.test_online

    # Seed is a string to allow for None/random as an input. Make it passable
    # to pl.seed_everything
    hparams.seed = None if 'None' in hparams.seed or hparams.seed == 'random' \
        else int(hparams.seed)

    # Get the hostname for book keeping
    hparams.hostname = socket.gethostname()

    # Set the seed (pl.seed_everything returns the seed actually used, so a
    # None/random seed is replaced by the concrete value for reproducibility)
    hparams.seed = pl.seed_everything(hparams.seed)

    # Turn the string entry for mapping into a dict (that is also a str)
    if hparams.mapping == 'default':
        hparams.mapping = const.DEFAULT_MAPPING
    elif hparams.mapping == 'random':
        hparams.mapping = str(
            Dataset.random_mapping(n_pentagons=hparams.n_pentagons))
    else:
        raise ValueError(f'Invalid entry for mapping: {hparams.mapping}')

    # Set the validation path
    hparams.val_path = str(const.DEFAULT_PATH)

    # Create experiment name
    hparams.name = name_from_hparams(hparams)
    hparams.exp_name = name_from_hparams(hparams, short=True)
    if hparams.verbose:
        print(f'Beginning experiment: "{hparams.name}"')

    # Neptune Logger
    logger = NeptuneLogger(
        project_name=f"{hparams.user}/{hparams.project}",
        experiment_name=hparams.exp_name,
        params=vars(hparams),
        tags=hparams.tags,
        offline_mode=hparams.offline_mode,
        upload_source_files=[
            str(Path(__file__).resolve()),
            inspect.getfile(Model),
            inspect.getfile(Dataset)
        ],
        close_after_fit=False,
    )

    if not hparams.load_model:
        # Checkpoint Call back
        if hparams.no_checkpoints:
            checkpoint = False
            if hparams.verbose:
                print('\nNot saving any checkpoints.\n', flush=True)
        else:
            dir_checkpoints_experiment = (Path(hparams.dir_checkpoints) /
                                          hparams.name)
            if not dir_checkpoints_experiment.exists():
                dir_checkpoints_experiment.mkdir(parents=True)
            checkpoint = pl.callbacks.ModelCheckpoint(
                filepath=str(
                    dir_checkpoints_experiment /
                    (f'seed={hparams.seed}' + '_{epoch}_{val_loss:.3f}')),
                verbose=hparams.verbose,
                save_top_k=hparams.save_top_k,
                period=hparams.checkpoint_period,
            )

        # Early stopping callback
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=hparams.early_stop_min_delta,
            patience=hparams.early_stop_patience,
            verbose=hparams.verbose,
            mode=hparams.early_stop_mode,
        )

        # Define the trainer
        trainer = pl.Trainer(
            checkpoint_callback=checkpoint,
            max_epochs=hparams.epochs,
            logger=logger,
            val_check_interval=hparams.val_check_interval,
            gpus=hparams.gpus,
            early_stop_callback=early_stop_callback,
        )

        # Verbose messaging
        if hparams.verbose:
            now = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
            print(f'\nCurrent time: {now}', flush=True)
            print(f'\nRunning with following hparams:', flush=True)
            pprint(vars(hparams))

        # Define the model
        model = Model(hparams)
        if hparams.verbose:
            print(f'\nModel being used: \n{model}', flush=True)

        # Define the datamodule
        datamodule = datasets.DataModuleConstructor(hparams, Dataset)

        # Train the model
        print('\nBeginning training:', flush=True)
        now = datetime.datetime.now()
        trainer.fit(model, datamodule=datamodule)
        if hparams.verbose:
            elapsed = datetime.datetime.now() - now
            elapsed_fstr = time.strftime('%H:%M:%S',
                                         time.gmtime(elapsed.seconds))
            print(f'\nTraining completed! Time Elapsed: {elapsed_fstr}',
                  flush=True)

        # Record the best checkpoint if we kept track of it
        if not hparams.no_checkpoints:
            logger.log_hyperparams(
                {'best_checkpoint_path': checkpoint.best_model_path})

        # Save the weights online if desired
        # NOTE(review): if --save_weights_online is combined with
        # --no_checkpoints, `checkpoint` is False and this raises
        # AttributeError — confirm whether these flags are mutually exclusive.
        if hparams.save_weights_online:
            if hparams.verbose:
                print('\nSending weights to neptune servers...', flush=True)
            logger.log_artifact(checkpoint.best_model_path)
            if hparams.verbose:
                print('Finished.', flush=True)
    else:
        # Loading a previously trained model (checkpoint lookup by experiment
        # name / val_loss) is not implemented yet; a draft lives in VCS history.
        raise NotImplementedError

    if not hparams.no_test:
        # Ensure we are in cuda for testing if specified
        if 'cuda' in hparams.device and torch.cuda.is_available():
            model.cuda()

        # Create the test data
        # NOTE(review): assumes embeddings are 2048-dim (ResNet features) —
        # confirm against the Dataset implementation.
        test_data = np.array(
            [datamodule.ds.array_data[n] for n in const.DEFAULT_PATH]).reshape(
                (1, len(const.DEFAULT_PATH), 2048))
        torch_data = torch.Tensor(test_data)

        # Get the model outputs
        outs = model.forward(torch_data, output_mode='eval')
        outs.update({'errors': model.forward(torch_data,
                                             output_mode='error')})

        # Visualize the test data
        figs = model.visualize(outs, borders=const.DEFAULT_BORDERS)
        if not hparams.no_graphs:
            for name, fig in figs.items():
                # Doing logger.log_image(...) doesn't work for some reason
                model.logger.log_image(name, fig)

    # Close the neptune logger
    logger.experiment.stop()
def main(hparams):
    """Main training routine specific for this project.

    Optionally wires up a Neptune tracker (when ``hparams.tracker.neptune``
    is configured and the ``neptune`` client is installed), logs environment
    and hydra artifacts, builds the Lightning model and trainer from
    ``hparams``, and runs training. Any exception stops the Neptune
    experiment before re-raising.

    NOTE(review): if this chunk shares a module with the other ``main``
    above, this definition shadows it — confirm the two entry points live
    in separate files.

    :param hparams: hydra/omegaconf-style config with ``tracker``, ``PL``
        and ``train`` sections (attribute plus ``.get`` access is used).
    """
    # ------------------------
    # 0 INIT TRACKER
    # ------------------------
    # https://docs.neptune.ai/integrations/pytorch_lightning.html
    try:
        import neptune
        NEPTUNE_AVAILABLE = True
    except ImportError:  # pragma: no-cover
        NEPTUNE_AVAILABLE = False

    USE_NEPTUNE = False
    if getattr(hparams, 'tracker', None) is not None:
        if getattr(hparams.tracker, 'neptune', None) is not None:
            USE_NEPTUNE = True

    if USE_NEPTUNE and not NEPTUNE_AVAILABLE:
        # Warn loudly (and pause so the warning is seen) but continue
        # without tracking rather than crashing the run.
        warnings.warn(
            'You want to use `neptune` logger which is not installed yet,'
            ' install it with `pip install neptune-client`.', UserWarning)
        time.sleep(5)

    tracker = None
    if NEPTUNE_AVAILABLE and USE_NEPTUNE:
        neptune_params = hparams.tracker.neptune
        # An API token may be supplied via a file path instead of the
        # NEPTUNE_API_TOKEN environment variable.
        fn_token = getattr(neptune_params, 'fn_token', None)
        if fn_token is not None:
            p = Path(neptune_params.fn_token).expanduser()
            if p.exists():
                with open(p, 'r') as f:
                    token = f.readline().splitlines()[0]
                os.environ['NEPTUNE_API_TOKEN'] = token
        hparams_flatten = dict_flatten(hparams, sep='.')
        experiment_name = hparams.tracker.get('experiment_name', None)
        tags = list(hparams.tracker.get('tags', []))
        offline_mode = hparams.tracker.get('offline', False)
        tracker = NeptuneLogger(
            project_name=neptune_params.project_name,
            experiment_name=experiment_name,
            params=hparams_flatten,
            tags=tags,
            offline_mode=offline_mode,
            upload_source_files=["../../../*.py"
                                 ],  # because hydra change current dir
        )

    try:
        # log
        if tracker is not None:
            # Environment/version watermark for reproducibility.
            watermark_s = watermark(packages=[
                'python', 'nvidia', 'cudnn', 'hostname', 'torch',
                'sparseconvnet', 'pytorch-lightning', 'hydra-core', 'numpy',
                'plyfile'
            ])
            log_text_as_artifact(tracker, watermark_s, "versions.txt")
            # arguments_of_script
            sysargs_s = str(sys.argv[1:])
            log_text_as_artifact(tracker, sysargs_s,
                                 "arguments_of_script.txt")
            # Hydra snapshots of overrides and resolved config, if present.
            for key in ['overrides.yaml', 'config.yaml']:
                p = Path.cwd() / '.hydra' / key
                if p.exists():
                    tracker.log_artifact(str(p), f'hydra_{key}')

        callbacks = []
        if tracker is not None:
            lr_logger = LearningRateLogger()
            callbacks.append(lr_logger)

        # ------------------------
        # 1 INIT LIGHTNING MODEL
        # ------------------------
        model = LightningTemplateModel(hparams)
        if tracker is not None:
            s = str(model)
            log_text_as_artifact(tracker, s, "model_summary.txt")

        # ------------------------
        # 2 INIT TRAINER
        # ------------------------
        cfg = hparams.PL
        if tracker is None:
            tracker = cfg.logger  # True by default in PL
        # Forward every PL config entry except 'logger', which is supplied
        # explicitly below.
        kwargs = dict(cfg)
        kwargs.pop('logger')
        trainer = pl.Trainer(
            max_epochs=hparams.train.max_epochs,
            callbacks=callbacks,
            logger=tracker,
            **kwargs,
        )

        # ------------------------
        # 3 START TRAINING
        # ------------------------
        print()
        print("Start training")
        trainer.fit(model)
    except (Exception, KeyboardInterrupt) as ex:
        # Stop the Neptune experiment with the failure reason, then re-raise
        # so the caller still sees the original error.
        if tracker is not None:
            print_exc()
            tracker.experiment.stop(str(ex))
        raise