示例#1
0
def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=None,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=None,
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs
):
    """Build a ``pl.Trainer`` from ``args`` and optionally fit ``model``.

    Seeds via ``pl.seed_everything(args.seed)``, creates
    ``model.hparams.output_dir`` if it does not exist, assembles the callback
    list, constructs the trainer with ``pl.Trainer.from_argparse_args`` and
    calls ``trainer.fit(model)`` when ``args.do_train`` is truthy.

    Args:
        model: Module handed to ``trainer.fit``; ``model.hparams.output_dir``
            is read to create the output directory.
        args: Parsed arguments. This function reads ``seed``, ``output_dir``,
            ``fp16``, ``fp16_opt_level`` (only when ``fp16`` is set), ``gpus``,
            ``accumulate_grad_batches`` and ``do_train``; the whole namespace
            is also forwarded to ``from_argparse_args``.
        early_stopping_callback: Optional callback added after the extra
            callbacks when truthy.
        logger: Passed through as the trainer's ``logger``.
        extra_callbacks: Optional iterable of additional callbacks. It is
            copied, so the caller's object is never mutated.
        checkpoint_callback: Checkpoint callback; when ``None`` a
            ``ModelCheckpoint`` monitoring ``val_loss`` (min, top-1) is built.
        logging_callback: Logging callback; defaults to ``LoggingCallback()``.
        **extra_train_kwargs: Accepted for call-site compatibility but
            currently NOT forwarded to the trainer.

    Returns:
        The constructed ``pl.Trainer`` (fitted if ``args.do_train``).
    """
    pl.seed_everything(args.seed)

    # init model
    odir = Path(model.hparams.output_dir)
    odir.mkdir(exist_ok=True)

    # Private copy: a mutable default (or the caller's own list) must not
    # accumulate callbacks across successive calls.
    callbacks = [] if extra_callbacks is None else list(extra_callbacks)

    # add custom checkpoints
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )
    if early_stopping_callback:
        callbacks.append(early_stopping_callback)
    if logging_callback is None:
        logging_callback = LoggingCallback()

    train_params = {}

    # TODO: remove with PyTorch 1.6 since pl uses native amp
    if args.fp16:
        train_params["precision"] = 16
        train_params["amp_level"] = args.fp16_opt_level

    if args.gpus > 1:
        train_params["accelerator"] = "ddp"

    train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
    # Profiling is always disabled here; extra_train_kwargs is not consulted.
    train_params["profiler"] = None

    trainer = pl.Trainer.from_argparse_args(
        args,
        weights_summary=None,
        callbacks=[logging_callback] + callbacks + [InitCallback()] + [checkpoint_callback],
        logger=logger,
        plugins=[DDPPlugin(find_unused_parameters=True)],  # this is needed in new pytorch-lightning new version
        val_check_interval=1,
        num_sanity_val_steps=2,
        **train_params,
    )

    if args.do_train:
        trainer.fit(model)

    return trainer
def test_v1_6_0_ddp_sync_batchnorm():
    """Passing ``sync_batchnorm`` to ``DDPPlugin`` must raise a deprecation warning."""
    expected_message = "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"
    with pytest.deprecated_call(match=expected_message):
        DDPPlugin(sync_batchnorm=False)
def test_v1_6_0_ddp_num_nodes():
    """Passing ``num_nodes`` to ``DDPPlugin`` must raise a deprecation warning."""
    expected_message = "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"
    with pytest.deprecated_call(match=expected_message):
        DDPPlugin(num_nodes=1)
def test_v1_6_0_ddp_plugin_task_idx():
    """Reading ``DDPPlugin.task_idx`` must raise a deprecation warning."""
    expected_message = 'Use `DDPPlugin.local_rank` instead'
    # Construct outside the context so only the attribute access is checked.
    ddp_plugin = DDPPlugin()
    with pytest.deprecated_call(match=expected_message):
        _ = ddp_plugin.task_idx
def train_default_zoobot_from_scratch(
        # absolutely crucial arguments
        save_dir,  # save model here
        schema,  # answer these questions
        # input data - specify *either* catalog (to be split) or the splits themselves
        catalog=None,
        train_catalog=None,
        val_catalog=None,
        test_catalog=None,
        # model training parameters
        model_architecture='efficientnet',
        batch_size=256,
        epochs=1000,
        patience=8,
        # data and augmentation parameters
        # datamodule_class=GalaxyDataModule,  # generic catalog of galaxies, will not download itself. Can replace with any datamodules from pytorch_galaxy_datasets
        color=False,
        resize_size=224,
        crop_scale_bounds=(0.7, 0.8),
        crop_ratio_bounds=(0.9, 1.1),
        # hardware parameters
        accelerator='auto',
        nodes=1,
        gpus=2,
        num_workers=4,
        prefetch_factor=4,
        mixed_precision=False,
        # replication parameters
        random_state=42,
        wandb_logger=None):
    """Train a Zoobot model from scratch, then test the best checkpoint.

    Builds a ``GalaxyDataModule`` from the supplied catalog(s), a plain
    PyTorch Zoobot model wrapped in ``GenericLightningModule`` with a
    multi-question loss, and a ``pl.Trainer`` with checkpointing and early
    stopping on ``val_loss``. Runs ``trainer.fit`` followed by
    ``trainer.test(ckpt_path='best')``.

    Args:
        save_dir: Output directory (created if missing, including parents).
            Checkpoints go to ``<save_dir>/checkpoints``; it is also the
            trainer's ``default_root_dir``.
        schema: Object providing ``label_cols``, ``answers`` and
            ``question_index_groups``.
        catalog: Single catalog handed to the datamodule. Mutually exclusive
            with ``train_catalog`` / ``val_catalog`` / ``test_catalog``.
        train_catalog: Pre-split training catalog (used when ``catalog`` is None).
        val_catalog: Pre-split validation catalog (used when ``catalog`` is None).
        test_catalog: Pre-split test catalog (used when ``catalog`` is None).
        model_architecture: Name passed to
            ``select_base_architecture_func_from_name``.
        batch_size: Datamodule batch size.
        epochs: Trainer ``max_epochs``.
        patience: ``EarlyStopping`` patience.
        color: If True the model gets 3 channels and images are not
            greyscaled; otherwise 1 channel and ``greyscale=True``.
        resize_size: Forwarded to the datamodule.
        crop_scale_bounds: Forwarded to the datamodule.
        crop_ratio_bounds: Forwarded to the datamodule.
        accelerator: Trainer ``accelerator``.
        nodes: Trainer ``num_nodes``; values > 1 require ``gpus == 2``.
        gpus: GPUs per node; more than one selects a ``DDPPlugin`` strategy.
        num_workers: Dataloader workers; must be > 0.
        prefetch_factor: Forwarded to the datamodule.
        mixed_precision: If True train at precision 16, else 32.
        random_state: Seed passed to ``pl.seed_everything``.
        wandb_logger: Passed as the trainer's ``logger``.

    Returns:
        None.

    Raises:
        AssertionError: If ``save_dir`` is None, ``num_workers`` is not
            positive, ``nodes > 1`` with ``gpus != 2``, or both ``catalog``
            and any pre-split catalog are supplied.
    """

    slurm_debugging_logs()

    pl.seed_everything(random_state)

    assert save_dir is not None
    # exist_ok avoids the check-then-create race when several DDP processes
    # start at once; makedirs also handles a nested save_dir.
    os.makedirs(save_dir, exist_ok=True)

    if color:
        logging.warning(
            'Training on color images, not converting to greyscale')
        channels = 3
    else:
        logging.info('Converting images to greyscale before training')
        channels = 1

    strategy = None
    if (gpus is not None) and (gpus > 1):
        # only works as plugins, not strategy
        # strategy = 'ddp'
        strategy = DDPPlugin(find_unused_parameters=False)
        logging.info('Using multi-gpu training')

    if nodes > 1:
        assert gpus == 2
        logging.info('Using multi-node training')
        # this hangs silently on Manchester's slurm cluster - perhaps you will have more success?

    precision = 32
    if mixed_precision:
        logging.info(
            'Training with automatic mixed precision. Will reduce memory footprint but may cause training instability for e.g. resnet'
        )
        precision = 16

    assert num_workers > 0

    # os.cpu_count() returns None when the count cannot be determined;
    # skip the advisory warnings then instead of raising TypeError.
    cpu_count = os.cpu_count()
    if (cpu_count is not None) and (gpus is not None) and (num_workers * gpus > cpu_count):
        logging.warning("""num_workers * gpu > num cpu.
            You may be spawning more dataloader workers than you have cpus, causing bottlenecks. 
            Suggest reducing num_workers.""")
    if (cpu_count is not None) and (num_workers > cpu_count):
        logging.warning("""num_workers > num cpu.
            You may be spawning more dataloader workers than you have cpus, causing bottlenecks. 
            Suggest reducing num_workers.""")

    if catalog is not None:
        assert train_catalog is None
        assert val_catalog is None
        assert test_catalog is None
        catalogs_to_use = {'catalog': catalog}
    else:
        assert catalog is None
        catalogs_to_use = {
            'train_catalog': train_catalog,
            'val_catalog': val_catalog,
            'test_catalog': test_catalog
        }

    datamodule = GalaxyDataModule(
        label_cols=schema.label_cols,
        # can take either a catalog (and split it), or a pre-split catalog
        **catalogs_to_use,
        #   augmentations parameters
        album=False,
        greyscale=not color,
        resize_size=resize_size,
        crop_scale_bounds=crop_scale_bounds,
        crop_ratio_bounds=crop_ratio_bounds,
        #   hardware parameters
        batch_size=
        batch_size,  # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch)
        num_workers=num_workers,
        prefetch_factor=prefetch_factor)
    datamodule.setup()

    get_architecture, representation_dim = select_base_architecture_func_from_name(
        model_architecture)

    model = define_model.get_plain_pytorch_zoobot_model(
        output_dim=len(schema.answers),
        include_top=True,
        channels=channels,
        get_architecture=get_architecture,
        representation_dim=representation_dim)

    # This just adds schema.question_index_groups as an arg to the usual (labels, preds) loss arg format
    # Would use lambda but multi-gpu doesn't support as lambda can't be pickled
    def loss_func(preds, labels):  # pytorch convention is preds, labels
        return losses.calculate_multiquestion_loss(
            labels, preds, schema.question_index_groups
        )  # my and sklearn convention is labels, preds

    lightning_model = define_model.GenericLightningModule(model, loss_func)

    callbacks = [
        ModelCheckpoint(dirpath=os.path.join(save_dir, 'checkpoints'),
                        monitor="val_loss",
                        save_weights_only=True,
                        mode='min',
                        save_top_k=3),
        EarlyStopping(monitor='val_loss', patience=patience, check_finite=True)
    ]

    trainer = pl.Trainer(
        log_every_n_steps=3,
        accelerator=accelerator,
        gpus=gpus,  # per node
        num_nodes=nodes,
        strategy=strategy,
        precision=precision,
        logger=wandb_logger,
        callbacks=callbacks,
        max_epochs=epochs,
        default_root_dir=save_dir)

    logging.info((trainer.training_type_plugin, trainer.world_size,
                  trainer.local_rank, trainer.global_rank, trainer.node_rank))

    trainer.fit(lightning_model, datamodule)

    trainer.test(
        model=lightning_model,
        datamodule=datamodule,
        ckpt_path=
        'best'  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
    )