def get_transform(train, args):
    """Pick the preprocessing pipeline for semantic segmentation.

    Args:
        train: Truthy to get the training preset.
        args: Options object; only ``args.weights`` is read, and only when
            ``train`` is falsy.

    Returns:
        ``presets.SegmentationPresetTrain(base_size=520, crop_size=480)`` when
        training; otherwise the transforms attached to the weights named by
        ``args.weights`` if that is set, else
        ``presets.SegmentationPresetEval(base_size=520)``.
    """
    if train:
        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)

    if args.weights:
        # Evaluation with named weights: reuse the transforms bundled with them.
        return PM.get_weight(args.weights).transforms()

    return presets.SegmentationPresetEval(base_size=520)
def get_transform(train, args):
    """Pick the preprocessing pipeline for object detection.

    Args:
        train: Truthy to get the training preset.
        args: Options object. ``args.data_augmentation`` is forwarded to the
            training preset; ``args.weights`` selects the evaluation
            transforms.

    Returns:
        ``presets.DetectionPresetTrain(args.data_augmentation)`` when
        training; otherwise the transforms attached to the weights named by
        ``args.weights`` if that is set, else ``presets.DetectionPresetEval()``.
    """
    if train:
        pipeline = presets.DetectionPresetTrain(args.data_augmentation)
    elif args.weights:
        # Evaluation with named weights: reuse the transforms bundled with them.
        selected_weights = PM.get_weight(args.weights)
        pipeline = selected_weights.transforms()
    else:
        pipeline = presets.DetectionPresetEval()
    return pipeline
def validate(model, args):
    """Evaluate ``model`` on every dataset named in ``args.val_dataset``.

    The preprocessing is taken from the weights named by ``args.weights`` when
    that is set, and is ``OpticalFlowPresetEval()`` otherwise. Recognised
    dataset names are ``"kitti"`` and ``"sintel"``; any other name produces a
    warning and is skipped.

    Args:
        model: Model handed through to ``_validate``.
        args: Options object. Reads ``val_dataset`` (may be ``None`` or empty),
            ``weights``, ``dataset_root``, and, for Kitti, ``batch_size`` and
            ``rank``.
    """
    val_datasets = args.val_dataset or []

    if args.weights:
        weights = PM.get_weight(args.weights)
        preprocessing = weights.transforms()
    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            # Sintel is evaluated once per rendering pass.
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the unrecognised *name*. ``val_dataset`` is either unbound
            # here (unknown name listed first) or still refers to the dataset
            # object built for a previous entry.
            warnings.warn(f"Can't validate on {name}, skipping.")
def main(args):
    """Run the video-classification training / evaluation script.

    Builds the Kinetics400 training and validation datasets (optionally via an
    on-disk cache), the clip samplers and data loaders, the model, an SGD
    optimizer and a per-iteration LR schedule. It optionally resumes from a
    checkpoint, then either evaluates once (``args.test_only``) or trains over
    ``range(args.start_epoch, args.epochs)``, evaluating after every epoch and
    writing checkpoints when ``args.output_dir`` is set.

    Args:
        args: Parsed options object. It is read throughout the function and is
            also mutated: ``lr_warmup_method`` is lower-cased and
            ``start_epoch`` is overwritten when resuming.

    Raises:
        ImportError: If ``args.weights`` is set but ``PM`` is ``None``.
        RuntimeError: If warmup is enabled and ``args.lr_warmup_method`` is
            neither ``"linear"`` nor ``"constant"``.
    """
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # NOTE(review): args.distributed / args.gpu / args.world_size are read below
    # but never assigned in this function; presumably this call populates them
    # — confirm against utils.init_distributed_mode.
    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)

    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        # The cache file holds a (dataset, directory) pair; only the dataset is used.
        dataset, _ = torch.load(cache_path)
        # Replace whatever transform was cached with the freshly built one.
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    # Elapsed seconds for building/loading the training dataset only.
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    # Without named weights use the eval preset; otherwise use the transforms
    # attached to the selected weights.
    if not args.weights:
        transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
    else:
        weights = PM.get_weight(args.weights)
        transform_test = weights.transforms()

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        # Replace whatever transform was cached with the freshly built one.
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    # Samplers are built over the datasets' video_clips, with
    # args.clips_per_video passed to each.
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        # Note: the distributed sampler wraps the clip *sampler*, not the dataset.
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    # The model constructor is looked up by name: in torchvision.models.video
    # when no weights are named, in PM.video otherwise.
    if not args.weights:
        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
    else:
        model = PM.video.__dict__[args.model](weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # The base learning rate is multiplied by args.world_size.
    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # A gradient scaler exists only when args.amp is set; it is None otherwise.
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    iters_per_epoch = len(data_loader)
    # Milestones are shifted back by the warmup epochs and converted to iterations.
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        # Normalised in place on args, so the lower-cased value is what later
        # ends up in the saved checkpoint's "args" entry.
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        # Warmup scheduler first, then the multi-step schedule from warmup_iters on.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    # Keep a handle on the unwrapped model: its state_dict is what gets
    # loaded on resume and saved in checkpoints.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # Continue with the epoch after the one stored in the checkpoint.
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            # Requires the checkpoint to contain a "scaler" entry.
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # Evaluation-only mode: run once and skip training entirely.
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            # Saved twice: a per-epoch file and a fixed-name "checkpoint.pth"
            # that is overwritten every epoch.
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
def load_data(traindir, valdir, args):
    """Build the image-classification datasets and their samplers.

    Each dataset is loaded from the path given by ``_get_cache_path`` when
    ``args.cache_dataset`` is set and that file exists. Otherwise it is
    created as an ``ImageFolder`` and, if ``args.cache_dataset`` is set, saved
    to that path. A dataset loaded from the cache is used as stored: its
    transform is not rebuilt from the current arguments.

    Args:
        traindir: Root directory of the training images.
        valdir: Root directory of the validation images.
        args: Options object. Reads ``val_resize_size``, ``val_crop_size``,
            ``train_crop_size``, ``interpolation``, ``cache_dataset``,
            ``weights``, ``distributed`` and, when distributed, ``ra_sampler``
            and ``ra_reps``. ``auto_augment`` and ``random_erase`` are
            optional and default to ``None`` and ``0.0``.

    Returns:
        The tuple ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    print("Loading data")
    resize_for_val = args.val_resize_size
    crop_for_val = args.val_crop_size
    crop_for_train = args.train_crop_size
    interp_mode = InterpolationMode(args.interpolation)

    # ---- training set -----------------------------------------------------
    print("Loading training data")
    started_at = time.time()
    cache_path = _get_cache_path(traindir)
    train_cache_hit = args.cache_dataset and os.path.exists(cache_path)
    if train_cache_hit:
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        cached_train = torch.load(cache_path)
        dataset, _ = cached_train
    else:
        train_transform = presets.ClassificationPresetTrain(
            crop_size=crop_for_train,
            interpolation=interp_mode,
            auto_augment_policy=getattr(args, "auto_augment", None),
            random_erase_prob=getattr(args, "random_erase", 0.0),
        )
        dataset = torchvision.datasets.ImageFolder(traindir, train_transform)
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - started_at)

    # ---- validation set ---------------------------------------------------
    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    val_cache_hit = args.cache_dataset and os.path.exists(cache_path)
    if val_cache_hit:
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        cached_val = torch.load(cache_path)
        dataset_test, _ = cached_val
    else:
        if args.weights:
            # Named weights carry their own evaluation transforms.
            eval_transform = PM.get_weight(args.weights).transforms()
        else:
            eval_transform = presets.ClassificationPresetEval(
                crop_size=crop_for_val,
                resize_size=resize_for_val,
                interpolation=interp_mode,
            )
        dataset_test = torchvision.datasets.ImageFolder(valdir, eval_transform)
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    # ---- samplers ---------------------------------------------------------
    print("Creating data loaders")
    if not args.distributed:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    else:
        if args.ra_sampler:
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)

    return dataset, dataset_test, train_sampler, test_sampler
def test_get_weight(name, weight):
    """Check that ``models.get_weight(name)`` compares equal to ``weight``."""
    resolved = models.get_weight(name)
    assert resolved == weight