def run_cli():
    """Train Faster R-CNN on Pascal VOC using options parsed from the command line.

    The parser is assembled from three argparse hooks: the VOC datamodule's,
    the Trainer's and the model's. The datamodule is then built from the
    parsed namespace, its class count is written back into that namespace,
    and the namespace is expanded into the ``FasterRCNN`` constructor before
    fitting.
    """
    from pl_bolts.datamodules import VOCDetectionDataModule

    seed_everything(42)

    # Each hook receives the parser and returns it with its options added;
    # they are applied in a fixed order: datamodule, trainer, model.
    cli_parser = ArgumentParser()
    for register_options in (
        VOCDetectionDataModule.add_argparse_args,
        Trainer.add_argparse_args,
        FasterRCNN.add_model_specific_args,
    ):
        cli_parser = register_options(cli_parser)
    cli_args = cli_parser.parse_args()

    voc_data = VOCDetectionDataModule.from_argparse_args(cli_args)
    # The model is constructed from vars(cli_args), so the dataset's class
    # count must be stored on the namespace before that call.
    cli_args.num_classes = voc_data.num_classes

    detector = FasterRCNN(**vars(cli_args))
    trainer = Trainer.from_argparse_args(cli_args)
    trainer.fit(detector, datamodule=voc_data)
def run_cli():
    """Train Faster R-CNN on Pascal VOC from command-line arguments.

    Trainer options come from ``pl.Trainer.add_argparse_args`` and model
    options from ``FasterRCNN.add_model_specific_args``. For the datamodule,
    only ``--data_dir`` (default ``"."``) and ``--batch_size`` (default ``1``)
    are exposed here.
    """
    from pl_bolts.datamodules import VOCDetectionDataModule

    pl.seed_everything(42)

    cli = pl.Trainer.add_argparse_args(ArgumentParser())
    # Datamodule options are declared by hand rather than through a hook.
    cli.add_argument("--data_dir", type=str, default=".")
    cli.add_argument("--batch_size", type=int, default=1)
    cli = FasterRCNN.add_model_specific_args(cli)
    opts = cli.parse_args()

    voc = VOCDetectionDataModule.from_argparse_args(opts)
    # Forwarded to the model constructor through vars(opts) below.
    opts.num_classes = voc.num_classes

    detector = FasterRCNN(**vars(opts))
    trainer = pl.Trainer.from_argparse_args(opts)
    trainer.fit(detector, datamodule=voc)
def run_cli():
    """Train a YOLO detector on Pascal VOC from command-line arguments.

    Steps:

    - Read the network definition from ``--config`` with ``YOLOConfiguration``.
    - Convert and resize every image to the configured input size.
    - Build the optimizer and learning-rate scheduler parameters from the CLI.
    - Optionally load initial weights from a Darknet weight file.
    - Fit the model on ``VOCDetectionDataModule`` with a ``Trainer`` built from
      the same parsed arguments.
    """
    from argparse import ArgumentParser

    from pytorch_lightning import Trainer, seed_everything

    from pl_bolts.datamodules import VOCDetectionDataModule
    from pl_bolts.datamodules.vocdetection_datamodule import Compose
    from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration

    seed_everything(42)

    parser = ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        metavar="PATH",
        required=True,
        help="read model configuration from PATH",
    )
    parser.add_argument(
        "--darknet-weights",
        type=str,
        metavar="PATH",
        help="read the initial model weights from PATH in Darknet format",
    )
    parser.add_argument(
        "--lr",
        type=float,
        metavar="LR",
        default=0.0013,
        help="learning rate after the warmup period",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        metavar="GAMMA",
        default=0.9,
        help="if nonzero, the optimizer uses momentum with factor GAMMA",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        metavar="LAMBDA",
        default=0.0005,
        help="if nonzero, the optimizer uses weight decay (L2 penalty) with factor LAMBDA",
    )
    parser.add_argument(
        "--warmup-epochs",
        type=int,
        metavar="N",
        default=1,
        help="learning rate warmup period is N epochs",
    )
    # NOTE(review): "--max-epochs" stores to dest "max_epochs". If
    # Trainer.add_argparse_args (below) also registers a "max_epochs" option,
    # both spellings write the same attribute. argparse fills in the default of
    # the first action registered for a dest, so 300 would then be the default
    # seen by both the LR scheduler and Trainer.from_argparse_args. Confirm this
    # sharing is intended.
    parser.add_argument(
        "--max-epochs",
        type=int,
        metavar="N",
        default=300,
        help="train at most N epochs",
    )
    parser.add_argument(
        "--initial-lr",
        type=float,
        metavar="LR",
        default=0.0,
        help="learning rate before the warmup period",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        metavar="THRESHOLD",
        default=0.001,
        help="keep predictions only if the confidence is above THRESHOLD",
    )
    parser.add_argument(
        "--nms-threshold",
        type=float,
        metavar="THRESHOLD",
        default=0.45,
        help="non-maximum suppression removes predicted boxes that have IoU greater than "
        "THRESHOLD with a higher scoring box",
    )
    parser.add_argument(
        "--max-predictions-per-image",
        type=int,
        metavar="N",
        default=100,
        help="keep at most N best predictions",
    )
    parser = VOCDetectionDataModule.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    config = YOLOConfiguration(args.config)

    # The same pipeline is used for training and validation: convert the image
    # to a tensor (target passed through unchanged), then resize to the input
    # height/width declared in the network configuration.
    transforms = [
        lambda image, target: (F.to_tensor(image), target),
        Resize((config.height, config.width)),
    ]
    transforms = Compose(transforms)
    datamodule = VOCDetectionDataModule.from_argparse_args(
        args, train_transforms=transforms, val_transforms=transforms
    )

    optimizer_params = {
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
    }
    lr_scheduler_params = {
        "warmup_epochs": args.warmup_epochs,
        "max_epochs": args.max_epochs,
        "warmup_start_lr": args.initial_lr,
    }
    model = YOLO(
        network=config.get_network(),
        optimizer_params=optimizer_params,
        lr_scheduler_params=lr_scheduler_params,
        confidence_threshold=args.confidence_threshold,
        nms_threshold=args.nms_threshold,
        max_predictions_per_image=args.max_predictions_per_image,
    )

    if args.darknet_weights is not None:
        # Darknet .weights files are binary. Open in binary mode so the raw
        # bytes reach the loader; text mode may decode or newline-translate the
        # data (e.g. on Windows), depending on how the loader reads the file.
        # NOTE(review): assumes load_darknet_weights consumes a binary file
        # object — confirm against its implementation.
        with open(args.darknet_weights, "rb") as weight_file:
            model.load_darknet_weights(weight_file)

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)