Code Example #1
def run_cli():
    # Extra imports so the snippet runs on its own.
    from argparse import ArgumentParser

    from pytorch_lightning import Trainer, seed_everything

    from pl_bolts.datamodules import VOCDetectionDataModule
    from pl_bolts.models.detection import FasterRCNN

    seed_everything(42)

    # Collect datamodule, Trainer, and model options on a single parser.
    parser = ArgumentParser()
    parser = VOCDetectionDataModule.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    parser = FasterRCNN.add_model_specific_args(parser)

    args = parser.parse_args()

    # The number of classes comes from the datamodule, not the command line.
    datamodule = VOCDetectionDataModule.from_argparse_args(args)
    args.num_classes = datamodule.num_classes

    model = FasterRCNN(**vars(args))
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)
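For readers who want the same flow without the argparse layer, here is a minimal sketch under the assumption that FasterRCNN can be imported from pl_bolts.models.detection; the data_dir, batch_size, and max_epochs values are purely illustrative.

from pytorch_lightning import Trainer, seed_everything

from pl_bolts.datamodules import VOCDetectionDataModule
from pl_bolts.models.detection import FasterRCNN

seed_everything(42)

# Illustrative settings; any VOC root directory and batch size work here.
datamodule = VOCDetectionDataModule(data_dir=".", batch_size=2)

# As in the CLI version, the number of classes is taken from the datamodule.
model = FasterRCNN(num_classes=datamodule.num_classes)

trainer = Trainer(max_epochs=10)  # illustrative trainer settings
trainer.fit(model, datamodule=datamodule)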
Code Example #2
def run_cli():
    # Extra imports so the snippet runs on its own.
    from argparse import ArgumentParser

    import pytorch_lightning as pl

    from pl_bolts.datamodules import VOCDetectionDataModule
    from pl_bolts.models.detection import FasterRCNN

    pl.seed_everything(42)

    # The data directory and batch size are registered by hand here instead of
    # through VOCDetectionDataModule.add_argparse_args.
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--data_dir", type=str, default=".")
    parser.add_argument("--batch_size", type=int, default=1)
    parser = FasterRCNN.add_model_specific_args(parser)

    args = parser.parse_args()

    # The number of classes comes from the datamodule, not the command line.
    datamodule = VOCDetectionDataModule.from_argparse_args(args)
    args.num_classes = datamodule.num_classes

    model = FasterRCNN(**vars(args))
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)
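Compared with Code Example #1, this variant exposes only the data directory and batch size on the command line rather than every VOCDetectionDataModule option. A minimal, illustrative entry point and launch command (the script name and flag values are placeholders) could look like this:

if __name__ == "__main__":
    run_cli()

# Example launch (placeholders; --max_epochs and --gpus are added by
# pl.Trainer.add_argparse_args):
#   python faster_rcnn_train.py --data_dir ./VOCdevkit --batch_size 4 --max_epochs 10 --gpus 1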
Code Example #3
def run_cli():
    from argparse import ArgumentParser

    from pytorch_lightning import Trainer, seed_everything
    from torchvision.transforms import functional as F

    from pl_bolts.datamodules import VOCDetectionDataModule
    from pl_bolts.datamodules.vocdetection_datamodule import Compose
    from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration
    # YOLO and its Resize transform are defined in pl_bolts' YOLO module; import
    # them here so the snippet runs standalone.
    from pl_bolts.models.detection.yolo.yolo_module import Resize, YOLO

    seed_everything(42)

    parser = ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        metavar="PATH",
        required=True,
        help="read model configuration from PATH",
    )
    parser.add_argument(
        "--darknet-weights",
        type=str,
        metavar="PATH",
        help="read the initial model weights from PATH in Darknet format",
    )
    parser.add_argument(
        "--lr",
        type=float,
        metavar="LR",
        default=0.0013,
        help="learning rate after the warmup period",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        metavar="GAMMA",
        default=0.9,
        help="if nonzero, the optimizer uses momentum with factor GAMMA",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        metavar="LAMBDA",
        default=0.0005,
        help="if nonzero, the optimizer uses weight decay (L2 penalty) "
        "with factor LAMBDA",
    )
    parser.add_argument(
        "--warmup-epochs",
        type=int,
        metavar="N",
        default=1,
        help="learning rate warmup period is N epochs",
    )
    parser.add_argument(
        "--max-epochs",
        type=int,
        metavar="N",
        default=300,
        help="train at most N epochs",
    )
    parser.add_argument(
        "--initial-lr",
        type=float,
        metavar="LR",
        default=0.0,
        help="learning rate before the warmup period",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        metavar="THRESHOLD",
        default=0.001,
        help="keep predictions only if the confidence is above THRESHOLD",
    )
    parser.add_argument(
        "--nms-threshold",
        type=float,
        metavar="THRESHOLD",
        default=0.45,
        help="non-maximum suppression removes predicted boxes that have IoU "
        "greater than THRESHOLD with a higher scoring box",
    )
    parser.add_argument(
        "--max-predictions-per-image",
        type=int,
        metavar="N",
        default=100,
        help="keep at most N best predictions",
    )

    parser = VOCDetectionDataModule.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    config = YOLOConfiguration(args.config)

    # Convert PIL images to tensors and resize image and targets to the input
    # resolution given by the Darknet configuration file.
    transforms = [
        lambda image, target: (F.to_tensor(image), target),
        Resize((config.height, config.width)),
    ]
    transforms = Compose(transforms)
    datamodule = VOCDetectionDataModule.from_argparse_args(
        args, train_transforms=transforms, val_transforms=transforms)

    # Hyperparameters forwarded to the optimizer and the learning rate scheduler.
    optimizer_params = {
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }
    lr_scheduler_params = {
        "warmup_epochs": args.warmup_epochs,
        "max_epochs": args.max_epochs,
        "warmup_start_lr": args.initial_lr,
    }
    # Build the YOLO Lightning module around the network parsed from the
    # Darknet configuration file.
    model = YOLO(
        network=config.get_network(),
        optimizer_params=optimizer_params,
        lr_scheduler_params=lr_scheduler_params,
        confidence_threshold=args.confidence_threshold,
        nms_threshold=args.nms_threshold,
        max_predictions_per_image=args.max_predictions_per_image,
    )
    # Optionally load initial weights stored in Darknet format.
    if args.darknet_weights is not None:
        with open(args.darknet_weights) as weight_file:
            model.load_darknet_weights(weight_file)

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)
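Only --config is required; the Darknet weights are optional. A minimal, illustrative entry point and launch command (the script name and the .cfg and .weights paths are placeholders) could look like this:

if __name__ == "__main__":
    run_cli()

# Example launch (placeholders; --data_dir, --batch_size, --max_epochs and
# --gpus come from the datamodule and Trainer argparse args):
#   python yolo_train.py --config yolov4-tiny.cfg --darknet-weights yolov4-tiny.weights \
#       --data_dir ./VOCdevkit --batch_size 4 --max_epochs 100 --gpus 1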