Example #1
import datetime
import inspect
import socket
import time
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import NeptuneLogger

# Project-local imports assumed by this snippet (exact module paths are
# guesses from the usage below): `index` and `const` hold path/mapping
# constants, `name_from_hparams` builds the experiment name string.
from prevseg import constants as const, datasets, index, models
from prevseg.utils import name_from_hparams


def main(parser):
    parser.add_argument('-m', '--model', type=str, default='PredNet')
    parser.add_argument('-d',
                        '--dataset',
                        type=str,
                        default='SchapiroResnetEmbeddingDataset')
    parser.add_argument('--load_model', action='store_true')
    parser.add_argument('--ipy', action='store_true')
    parser.add_argument('--no_graphs', action='store_true')
    parser.add_argument('--no_test', action='store_true')
    parser.add_argument('--user', type=str, default='aprashedahmed')
    parser.add_argument('-p', '--project', type=str, default='sandbox')
    parser.add_argument('-t', '--tags', nargs='+')
    parser.add_argument('--no_checkpoints', action='store_true')
    parser.add_argument('--offline_mode', action='store_true')
    parser.add_argument('--save_weights_online', action='store_true')

    parser.add_argument('--test_run', action='store_true')
    parser.add_argument('--test_checkpoints', action='store_true')
    parser.add_argument('--test_epochs', type=int, default=2)
    parser.add_argument('--test_n_paths', type=int, default=2)
    parser.add_argument('--test_online', action='store_true')
    parser.add_argument('--test_project', type=str, default='')
    parser.add_argument('-v', '--verbose', action='store_true')

    parser.add_argument('--n_workers', type=int, default=1)
    parser.add_argument('-e', '--epochs', type=int, default=50)
    parser.add_argument('--gpus', type=float, default=1)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('-s', '--seed', type=str, default='random')
    parser.add_argument('-b', '--batch_size', type=int, default=256 + 128)
    parser.add_argument('--n_val', type=int, default=1)
    parser.add_argument('--mapping', type=str, default='random')

    parser.add_argument('--dir_checkpoints',
                        type=str,
                        default=str(index.DIR_CHECKPOINTS))
    parser.add_argument('--checkpoint_period', type=float, default=1.0)
    parser.add_argument('--val_check_interval', type=float, default=1.0)
    parser.add_argument('--save_top_k', type=int, default=1)
    parser.add_argument('--early_stop_mode', type=str, default='min')
    parser.add_argument('--early_stop_patience', type=int, default=10)
    parser.add_argument('--early_stop_min_delta', type=float, default=0.001)

    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--exp_prefix', type=str, default='')
    parser.add_argument('--exp_suffix', type=str, default='')

    # Get Model and Dataset specific args
    temp_args, _ = parser.parse_known_args()
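    # (Two-stage parsing: parse_known_args() tolerates flags that haven't
    # been registered yet, so we can pick the Dataset/Model classes first
    # and then re-parse once their specific arguments are attached.)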

    # Make sure the requested dataset class exists and is importable
    if hasattr(datasets, temp_args.dataset):
        Dataset = getattr(datasets, temp_args.dataset)
        parser = Dataset.add_dataset_specific_args(parser)
    else:
        raise ValueError(
            f'Invalid dataset "{temp_args.dataset}" passed. Check it is '
            f'importable: "from prevseg.datasets import {temp_args.dataset}"')

    # Get temp args now with dataset args added
    temp_args, _ = parser.parse_known_args()

    # Likewise, make sure the requested model class exists
    if hasattr(models, temp_args.model):
        Model = getattr(models, temp_args.model)
        parser = Model.add_model_specific_args(parser)
    else:
        raise ValueError(
            f'Invalid model "{temp_args.model}" passed. Check it is importable:'
            f' "from prevseg.models import {temp_args.model}"')

    # Get the parser and turn into an omegaconf
    hparams = parser.parse_args()

    # If we are test-running, do a few things differently (scale down dataset,
    # send to sandbox project, etc.)
    if hparams.test_run:
        hparams.epochs = hparams.test_epochs
        hparams.n_paths = hparams.test_n_paths
        hparams.name = '_'.join(filter(None, ['test_run', hparams.exp_prefix]))
        hparams.project = hparams.test_project or 'sandbox'
        hparams.verbose = True
        hparams.ipy = True
        hparams.no_checkpoints = not hparams.test_checkpoints
        hparams.offline_mode = not hparams.test_online

    # Seed is given as a string so that 'None'/'random' can be passed on the
    # command line. Convert it to something pl.seed_everything can accept.
    hparams.seed = (None if hparams.seed in ('None', 'random')
                    else int(hparams.seed))

    # Get the hostname for bookkeeping
    hparams.hostname = socket.gethostname()

    # Set the seed
    hparams.seed = pl.seed_everything(hparams.seed)
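    # (pl.seed_everything returns the seed it actually used, so when a
    # random seed is drawn it still gets recorded in hparams.)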

    # Turn the string entry for mapping into a dict (stored as its str repr)
    if hparams.mapping == 'default':
        hparams.mapping = const.DEFAULT_MAPPING
    elif hparams.mapping == 'random':
        hparams.mapping = str(
            Dataset.random_mapping(n_pentagons=hparams.n_pentagons))
    else:
        raise ValueError(f'Invalid entry for mapping: {hparams.mapping}')

    # Set the validation path
    hparams.val_path = str(const.DEFAULT_PATH)

    # Create experiment name
    hparams.name = name_from_hparams(hparams)
    hparams.exp_name = name_from_hparams(hparams, short=True)
    if hparams.verbose:
        print(f'Beginning experiment: "{hparams.name}"')

    # Neptune Logger
    logger = NeptuneLogger(
        project_name=f"{hparams.user}/{hparams.project}",
        experiment_name=hparams.exp_name,
        params=vars(hparams),
        tags=hparams.tags,
        offline_mode=hparams.offline_mode,
        upload_source_files=[
            str(Path(__file__).resolve()),
            inspect.getfile(Model),
            inspect.getfile(Dataset)
        ],
        close_after_fit=False,
    )

    if not hparams.load_model:
        # Checkpoint callback
        if hparams.no_checkpoints:
            checkpoint = False
            if hparams.verbose:
                print('\nNot saving any checkpoints.\n', flush=True)
        else:
            dir_checkpoints_experiment = (Path(hparams.dir_checkpoints) /
                                          hparams.name)
            if not dir_checkpoints_experiment.exists():
                dir_checkpoints_experiment.mkdir(parents=True)

            checkpoint = pl.callbacks.ModelCheckpoint(
                filepath=str(
                    dir_checkpoints_experiment /
                    (f'seed={hparams.seed}' + '_{epoch}_{val_loss:.3f}')),
                verbose=hparams.verbose,
                save_top_k=hparams.save_top_k,
                period=hparams.checkpoint_period,
            )

        # Early stopping callback
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=hparams.early_stop_min_delta,
            patience=hparams.early_stop_patience,
            verbose=hparams.verbose,
            mode=hparams.early_stop_mode,
        )

        # Define the trainer
        trainer = pl.Trainer(
            checkpoint_callback=checkpoint,
            max_epochs=hparams.epochs,
            logger=logger,
            val_check_interval=hparams.val_check_interval,
            gpus=hparams.gpus,
            early_stop_callback=early_stop_callback,
        )

        # Verbose messaging
        if hparams.verbose:
            now = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
            print(f'\nCurrent time: {now}', flush=True)
            print('\nRunning with the following hparams:', flush=True)
            pprint(vars(hparams))

        # Define the model
        model = Model(hparams)
        if hparams.verbose:
            print(f'\nModel being used: \n{model}', flush=True)

        # Define the datamodule
        datamodule = datasets.DataModuleConstructor(hparams, Dataset)

        # Train the model
        print('\nBeginning training:', flush=True)
        now = datetime.datetime.now()
        trainer.fit(model, datamodule=datamodule)
        if hparams.verbose:
            elapsed = datetime.datetime.now() - now
            elapsed_fstr = time.strftime('%H:%M:%S',
                                         time.gmtime(elapsed.seconds))
            print(f'\nTraining completed! Time Elapsed: {elapsed_fstr}',
                  flush=True)

        # Record the best checkpoint if we kept track of it
        if not hparams.no_checkpoints:
            logger.log_hyperparams(
                {'best_checkpoint_path': checkpoint.best_model_path})
            # Save the weights online if desired
            if hparams.save_weights_online:
                if hparams.verbose:
                    print('\nSending weights to neptune servers...',
                          flush=True)
                logger.log_artifact(checkpoint.best_model_path)
                if hparams.verbose:
                    print('Finished.', flush=True)

    else:
        raise NotImplementedError
        # # Get all the experiments with the name hparams.name*
        # experiments = list(index.DIR_CHECKPOINTS.glob(
        #     f'{hparams.name}_{hparams.exp_name}*'))

        # if len(experiments) > 1:
        #     # Get the newest exp by v number
        #     experiment_newest = sorted(
        #         experiments,
        #         key=lambda path: int(path.stem.split('_')[-1][1:]))[-1]
        #     # Get the model with the best (lowest) val_loss
        # else:
        #     experiment_newest = experiments[0]
        # experiment_newest_best_val = sorted(
        #     experiment_newest.iterdir(),
        #     key=lambda path: float(
        #         path.stem.split('val_loss=')[-1].split('_')[0]))[0]

        # model = Model.load_from_checkpoint(str(experiment_newest_best_val))
        # model.logger = logger
        # ## LOOK AT THIS LATER
        # model.prepare_data(val_path=const.DEFAULT_PATH)

        # # Define the trainer
        # trainer = pl.Trainer(
        #     logger=model.logger,
        #     gpus=hparams.gpus,
        #     max_epochs=1,
        # )

    if not hparams.no_test:
        # Move the model to the GPU for testing if requested and available
        if 'cuda' in hparams.device and torch.cuda.is_available():
            model.cuda()

        # Create the test data
        test_data = np.array(
            [datamodule.ds.array_data[n] for n in const.DEFAULT_PATH]).reshape(
                (1, len(const.DEFAULT_PATH), 2048))
        torch_data = torch.Tensor(test_data)

        # Get the model outputs
        outs = model.forward(torch_data, output_mode='eval')
        outs.update({'errors': model.forward(torch_data, output_mode='error')})

        # Visualize the test data
        figs = model.visualize(outs, borders=const.DEFAULT_BORDERS)
        if not hparams.no_graphs:
            for name, fig in figs.items():
                # Doing logger.log_image(...) doesn't work for some reason
                model.logger.log_image(name, fig)

    # Close the neptune logger
    logger.experiment.stop()
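
For reference, a minimal sketch of the entry point that would drive this main() (the ArgumentParser construction here is an assumption, not part of the original snippet):

if __name__ == '__main__':
    import argparse
    # Hypothetical driver: main() registers all arguments on the parser
    # it receives and handles parsing itself.
    main(argparse.ArgumentParser(description='Train a prevseg model'))
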
Example #2
import os
import sys
import time
import warnings
from pathlib import Path
from traceback import print_exc

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateLogger
from pytorch_lightning.loggers import NeptuneLogger

# Project-local helpers assumed by this snippet (exact module paths are
# guesses from the usage below): `dict_flatten` flattens a nested config,
# `watermark` collects package/version info, `log_text_as_artifact` uploads
# a string to the tracker, and `LightningTemplateModel` is the
# LightningModule being trained.
from project.utils import dict_flatten, log_text_as_artifact, watermark
from project.models import LightningTemplateModel


def main(hparams):
    """Main training routine specific for this project.

    :param hparams: hydra/omegaconf config with `tracker`, `train`, and
        `PL` sections.
    """

    # 0 INIT TRACKER
    # https://docs.neptune.ai/integrations/pytorch_lightning.html
    try:
        import neptune
        NEPTUNE_AVAILABLE = True
    except ImportError:  # pragma: no cover
        NEPTUNE_AVAILABLE = False

    USE_NEPTUNE = False
    if getattr(hparams, 'tracker', None) is not None:
        if getattr(hparams.tracker, 'neptune', None) is not None:
            USE_NEPTUNE = True

    if USE_NEPTUNE and not NEPTUNE_AVAILABLE:
        warnings.warn(
            'The `neptune` logger was requested but is not installed;'
            ' install it with `pip install neptune-client`.', UserWarning)
        time.sleep(5)

    tracker = None

    if NEPTUNE_AVAILABLE and USE_NEPTUNE:
        neptune_params = hparams.tracker.neptune
        fn_token = getattr(neptune_params, 'fn_token', None)
        if fn_token is not None:
            p = Path(fn_token).expanduser()
            if p.exists():
                with open(p, 'r') as f:
                    os.environ['NEPTUNE_API_TOKEN'] = f.readline().strip()

        hparams_flatten = dict_flatten(hparams, sep='.')
        experiment_name = hparams.tracker.get('experiment_name', None)
        tags = list(hparams.tracker.get('tags', []))
        offline_mode = hparams.tracker.get('offline', False)

        tracker = NeptuneLogger(
            project_name=neptune_params.project_name,
            experiment_name=experiment_name,
            params=hparams_flatten,
            tags=tags,
            offline_mode=offline_mode,
            upload_source_files=["../../../*.py"],  # hydra changes the cwd
        )

    try:

        # Log environment info and hydra configs as artifacts
        if tracker is not None:
            watermark_s = watermark(packages=[
                'python', 'nvidia', 'cudnn', 'hostname', 'torch',
                'sparseconvnet', 'pytorch-lightning', 'hydra-core', 'numpy',
                'plyfile'
            ])
            log_text_as_artifact(tracker, watermark_s, "versions.txt")
            # arguments_of_script
            sysargs_s = str(sys.argv[1:])
            log_text_as_artifact(tracker, sysargs_s, "arguments_of_script.txt")

            for key in ['overrides.yaml', 'config.yaml']:
                p = Path.cwd() / '.hydra' / key
                if p.exists():
                    tracker.log_artifact(str(p), f'hydra_{key}')

        callbacks = []
        if tracker is not None:
            lr_logger = LearningRateLogger()
            callbacks.append(lr_logger)

        # ------------------------
        # 1 INIT LIGHTNING MODEL
        # ------------------------
        model = LightningTemplateModel(hparams)

        if tracker is not None:
            s = str(model)
            log_text_as_artifact(tracker, s, "model_summary.txt")

        # ------------------------
        # 2 INIT TRAINER
        # ------------------------
        cfg = hparams.PL

        if tracker is None:
            tracker = cfg.logger  # True by default in PL

        kwargs = dict(cfg)
        kwargs.pop('logger')
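        # (Everything else in the PL config section is passed straight
        # through to the Trainer; 'logger' is dropped so the tracker chosen
        # above takes precedence.)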

        trainer = pl.Trainer(
            max_epochs=hparams.train.max_epochs,
            callbacks=callbacks,
            logger=tracker,
            **kwargs,
        )

        # ------------------------
        # 3 START TRAINING
        # ------------------------
        print()
        print("Start training")

        trainer.fit(model)

    except (Exception, KeyboardInterrupt) as ex:
        if tracker is not None:
            print_exc()
            tracker.experiment.stop(str(ex))
        raise
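
Since this snippet reads .hydra/overrides.yaml and globs source files relative to hydra's working directory, a hydra entry point along these lines would fit (a sketch; the config path and name are assumptions):

import hydra

@hydra.main(config_path='conf', config_name='config')
def run(hparams):
    # hydra composes the config and changes the cwd before calling this,
    # which is why main() resolves .hydra/ files relative to Path.cwd().
    main(hparams)

if __name__ == '__main__':
    run()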