def get_imagenet_test_data_loader(args):
    valdir = os.path.join(args.data_path, "val")
    resize_size, crop_size, interpolation = get_transform_params(args.arch)

    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            presets.ClassificationPresetEval(crop_size=crop_size,
                                             resize_size=resize_size,
                                             interpolation=interpolation))
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    imagenet_test_data_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler,
        num_workers=args.workers, pin_memory=True)
    return imagenet_test_data_loader
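# Usage sketch for the loader above (illustrative, not the repo's CLI): the
# function only needs an `args` namespace carrying the fields it reads, so an
# ad-hoc Namespace is enough to exercise it. Paths and values are assumptions.
import argparse

args = argparse.Namespace(data_path="/datasets/imagenet", arch="resnet50",
                          cache_dataset=False, batch_size=32, workers=8)
loader = get_imagenet_test_data_loader(args)
for images, targets in loader:
    pass  # run inference on each validation batch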
def load_data(traindir, valdir, cache_dataset, distributed):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        if cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        if cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
def main(args):
    # Data loader & dataset
    dataset, num_classes = get_dataset(args.dataset, "train",
                                       get_transform(train=True), args.data_path)
    train_sampler = torch.utils.data.RandomSampler(dataset)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                        args.batch_size,
                                                        drop_last=True)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes,
        pretrained=args.pretrained,
    )
    print(model)
    device = torch.device(args.device)
    model.to(device)

    # Optimizer & lr_scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    # TODO: resume & distributed
    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            utils.save_on_master(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'args': args
                }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def load_data(traindir, valdir, args):
    # Data loading code
    print("Loading data")
    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=crop_size,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob))
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            presets.ClassificationPresetEval(crop_size=crop_size,
                                             resize_size=resize_size))
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
def save_chk_point(model_without_ddp, optimizer, lr_scheduler, epoch, acc5):
    global args
    global best_acc5
    if acc5 > best_acc5:
        best_acc5 = acc5
        checkpoint = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'args': args
        }
        utils.save_on_master(
            checkpoint, os.path.join(args.output_dir, 'ckp_{}.pth'.format(acc5)))
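# save_chk_point keeps its state in module-level globals (`args`, `best_acc5`),
# which works inside a single script but is easy to misuse. A global-free
# sketch of the same "save only on improvement" idea, reusing the script's
# utils.save_on_master helper (the function name here is hypothetical):
import os

def save_if_best(state, acc5, best_acc5, output_dir):
    """Write the checkpoint only when acc5 improves; return the new best acc5."""
    if acc5 > best_acc5:
        utils.save_on_master(state, os.path.join(output_dir, 'ckp_{}.pth'.format(acc5)))
        return acc5
    return best_acc5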
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
    # dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    scale_min = 0.5
    scale_max = 1.75
    rotate_min = -1
    rotate_max = 1
    train_h = 512
    train_w = 1024
    ignore_label = 255

    train_transform = T.Compose([
        T.RandScale([scale_min, scale_max]),
        T.RandRotate([rotate_min, rotate_max], padding=mean, ignore_label=ignore_label),
        T.RandomGaussianBlur(),
        T.RandomHorizontalFlip(),
        T.Crop([train_h, train_w], crop_type='rand', padding=mean, ignore_label=ignore_label),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])
    dataset_train = dataset.CityscapesData(split='train',
                                           data_root=args.data_root,
                                           data_list=args.train_list,
                                           transform=train_transform)

    val_transform = T.Compose([
        T.Crop([train_h, train_w], crop_type='center', padding=mean, ignore_label=ignore_label),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])
    dataset_test = dataset.CityscapesData(split='val',
                                          data_root=args.data_root,
                                          data_list=args.val_list,
                                          transform=val_transform)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size, sampler=train_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn, drop_last=True)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    num_classes = 19
    # model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,
    #                                                              aux_loss=args.aux_loss,
    #                                                              pretrained=args.pretrained)
    if args.pretrained:
        supernet = OFAMobileNetV3(
            n_classes=1000,
            dropout_rate=0,
            width_mult_list=1.2,
            ks_list=[3, 5, 7],
            expand_ratio_list=[3, 4, 6],
            depth_list=[2, 3, 4],
        )
        arch = OFAArchitecture.from_legency_string(args.arch)
        supernet.set_active_subnet(ks=arch.ks, e=arch.ratios, d=arch.depths)
        model = supernet.get_active_subnet()
        s = torch.load("model_best.pth.tar", map_location="cpu")
        model.load_state_dict(s["state_dict_ema"])
        model = convert2segmentation(model=model, begin_index_index=17)
        print("Loaded pretrained model.")
    else:
        supernet = SPOSMobileNetV3Segmentation(width_mult=1.2)
        model = supernet.get_subnet(OFAArchitecture.from_legency_string(args.arch))
    model.to(device)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.pretrained:
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.stem.parameters() if p.requires_grad]},
        ]
        if args.aux_loss:
            params = [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]
            params_to_optimize.append({"params": params, "lr": args.lr * 10})
        optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.first_conv.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.blocks.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.remain_block.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.head.parameters() if p.requires_grad]},
        ]
        if args.aux_loss:
            params = [p for p in model_without_ddp.aux_head.parameters() if p.requires_grad]
            params_to_optimize.append({"params": params, "lr": args.lr})
        optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler,
                        device, epoch, args.print_freq)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        utils.save_on_master(
            {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train",
                                       get_transform(train=True))
    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val",
                                  get_transform(train=False))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn, drop_last=True)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    model = torchvision.models.segmentation.__dict__[args.model](
        num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained)
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler,
                        device, epoch, args.print_freq)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        utils.save_on_master(
            {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
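# The LambdaLR above is the usual "poly" decay for segmentation: with
# total_iters = len(data_loader) * args.epochs, the lr factor at step x is
# (1 - x / total_iters) ** 0.9, decaying smoothly toward zero; train_one_epoch
# is expected to call lr_scheduler.step() every iteration. A minimal
# standalone check of the schedule (values illustrative):
import torch

net = torch.nn.Linear(4, 2)  # stand-in model
opt = torch.optim.SGD(net.parameters(), lr=0.01)
total_iters = 100  # e.g. len(data_loader) * args.epochs
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda x: (1 - x / total_iters) ** 0.9)
for _ in range(3):
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # 0.01 * (1 - 3/100) ** 0.9 ≈ 0.00973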
def main(args):
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train, num_replicas=num_tasks,
                                      rank=global_rank, shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=int(1.5 * args.batch_size),
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
    )
    # TODO: finetuning
    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP
        # but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()
    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch, loss_scaler,
                                      args.clip_grad, model_ema, mixup_fn)
        lr_scheduler.step(epoch)

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters
        }
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
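# The `linear_scaled_lr` line above applies the linear scaling rule:
# effective lr = base lr * global batch size / 512. Worked example with
# illustrative numbers (base lr 5e-4, batch 128 per process, 8 processes):
base_lr, batch_size, world_size = 5e-4, 128, 8
print(base_lr * batch_size * world_size / 512.0)  # 0.001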
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.use_deterministic_algorithms(True)

    # Data loading code
    print("Loading data")
    dataset, num_classes = get_dataset(args.dataset, "train",
                                       get_transform(True, args), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val",
                                  get_transform(False, args), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.batch_size,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    print("Creating model")
    kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    if args.data_augmentation in ["multiscale", "lsj"]:
        kwargs["_skip_resize"] = True
    if "rcnn" in args.model:
        if args.rpn_score_thresh is not None:
            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
    model = torchvision.models.detection.__dict__[args.model](
        weights=args.weights, weights_backbone=args.weights_backbone,
        num_classes=num_classes, **kwargs)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.norm_weight_decay is None:
        parameters = [p for p in model.parameters() if p.requires_grad]
    else:
        param_groups = torchvision.ops._utils.split_normalization_params(model)
        wd_groups = [args.norm_weight_decay, args.weight_decay]
        parameters = [{"params": p, "weight_decay": w}
                      for p, w in zip(param_groups, wd_groups) if p]

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr,
                                      weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        torch.backends.cudnn.deterministic = True
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq, scaler)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
def train_dino(args):
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    # dataset = datasets.ImageFolder(args.data_path, transform=transform)
    from sen12ms import get_transform
    dataset = AllSen12MSDataset(args.data_path, "train", transform=transform,
                                tansform_coord=None, classes=None, seasons=None,
                                split_by_region=True, download=False)

    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=0.1,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim
        student = utils.replace_input_layer(student, inchannels=13)
        teacher = utils.replace_input_layer(teacher, inchannels=13)
    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        print(f"Unknown architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(
        student,
        DINOHead(
            embed_dim,
            args.out_dim,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number + 2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        args.min_lr,
        args.epochs,
        len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print("Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training!")
    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp,
                                      dino_loss, data_loader, optimizer,
                                      lr_schedule, wd_schedule, momentum_schedule,
                                      epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(save_dict,
                                 os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            'epoch': epoch
        }
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
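# utils.cosine_scheduler above is the repo's helper that precomputes one value
# per training iteration. A sketch of the schedule it is assumed to produce
# (linear warmup to the base value, then cosine decay to the final value):
import numpy as np

def cosine_schedule(base_value, final_value, epochs, iters_per_epoch,
                    warmup_epochs=0, start_warmup_value=0):
    # linear warmup from start_warmup_value to base_value ...
    warmup_iters = warmup_epochs * iters_per_epoch
    warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
    # ... then cosine decay from base_value to final_value over the rest
    iters = np.arange(epochs * iters_per_epoch - warmup_iters)
    cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
    return np.concatenate((warmup, cosine))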
def main(args):
    if args.prototype and prototype is None:
        raise ImportError(
            "The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if not args.prototype and args.weights:
        raise ValueError(
            "The weights parameter works only in prototype mode. Please pass the --prototype argument.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # Data loading code
    print("Loading data")
    dataset, num_classes = get_dataset(args.dataset, "train",
                                       get_transform(True, args), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val",
                                  get_transform(False, args), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.batch_size,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    print("Creating model")
    kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    if "rcnn" in args.model:
        if args.rpn_score_thresh is not None:
            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
    if not args.prototype:
        model = torchvision.models.detection.__dict__[args.model](
            pretrained=args.pretrained, num_classes=num_classes, **kwargs)
    else:
        model = prototype.models.detection.__dict__[args.model](
            weights=args.weights, num_classes=num_classes, **kwargs)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq, scaler)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
def main():
    args = get_args()
    if args.output_dir:
        utils.mkdir(args.output_dir)
    utils.init_distributed_mode(args)

    # Data loading
    print("Loading data")
    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.b)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.b,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.b, sampler=test_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    # Model creation
    print("Creating model")
    # model = models.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained)
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes, pretrained=args.pretrained)
    device = torch.device(args.device)
    model.to(device)

    # Distributed
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # Parallel
    if args.parallel:
        print('Training parallel')
        model = torch.nn.DataParallel(model, device_ids=[args.gpu]).cuda()
        model_without_ddp = model.module

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    # Resume training
    if args.resume:
        print('Resume training')
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    # Training
    print('Start training')
    start_time = time.time()
    for epoch in range(args.epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'args': args
                }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def train_one_epoch(model, arch, optimizer, lr_scheduler, data_loader, device,
                    epoch, print_freq, ngpus_per_node, model_without_ddp, args):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    # header = "Epoch: [{}]".format(epoch)

    for images, targets in metric_logger.log_every(
            iterable=data_loader,
            print_freq=print_freq,
            # header=header,
            iter_num=args.iter_num):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Example target structure:
        # [{"boxes": tensor([], device="cuda:0"),
        #   "labels": tensor([], device="cuda:0", dtype=torch.int64),
        #   "masks": tensor([], device="cuda:0", dtype=torch.uint8),
        #   "iscrowd": tensor([], device="cuda:0", dtype=torch.int64)}]
        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            if not math.isfinite(loss_value):
                logger.fatal("Loss is {}, stopping training".format(loss_value))
                logger.fatal(loss_dict_reduced)
                sys.exit(1)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            lr_scheduler.step()

            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        except Exception as e:
            logger.warning(e, exc_info=True)
            # logger.info("print target for debug")
            # print(targets)

        args.iter_num += 1

        # save checkpoint here
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            if args.iter_num % 1000 == 0:
                utils.save_on_master({
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                    "iter_num": args.iter_num,
                    "args": args,
                }, "{}/{}_{}.pth".format(checkpoint_dir, arch.__name__, args.iter_num))

                os.makedirs("{}/debug_image/".format(checkpoint_dir), exist_ok=True)
                if args.iter_num < 5000:
                    continue

                # render debug overlays on a folder of test images
                model.eval()
                from barez import overlay_ann

                debug_image = None
                debug_image_list = []
                cnt = 0
                for image_path in glob.glob("./table_test/*"):
                    cnt += 1
                    image_name = os.path.basename(image_path)
                    # print(image_name)
                    image = cv2.imread(image_path)
                    rat = 1300 / image.shape[0]
                    image = cv2.resize(image, None, fx=rat, fy=rat)

                    transform = transforms.Compose([transforms.ToTensor()])
                    image = transform(image)

                    # put the model in evaluation mode
                    with torch.no_grad():
                        tensor = [image.to(device)]
                        prediction = model(tensor)

                    image = torch.squeeze(image, 0).permute(1, 2, 0).mul(255).numpy().astype(np.uint8)
                    for pred in prediction:
                        for idx, mask in enumerate(pred['masks']):
                            if pred['scores'][idx].item() < 0.5:
                                continue
                            m = mask[0].mul(255).byte().cpu().numpy()
                            box = list(map(int, pred["boxes"][idx].tolist()))
                            score = pred["scores"][idx].item()
                            image = overlay_ann(image, m, box, "", score)

                    if debug_image is None:
                        debug_image = image
                    else:
                        debug_image = np.concatenate((debug_image, image), axis=1)
                    if cnt == 10:
                        cnt = 0
                        debug_image_list.append(debug_image)
                        debug_image = None

                avg_length = np.mean([i.shape[1] for i in debug_image_list])
                di = None
                for debug_image in debug_image_list:
                    rat = avg_length / debug_image.shape[1]
                    debug_image = cv2.resize(debug_image, None, fx=rat, fy=rat)
                    if di is None:
                        di = debug_image
                    else:
                        di = np.concatenate((di, debug_image), axis=0)
                di = cv2.resize(di, None, fx=0.4, fy=0.4)
                cv2.imwrite("{}/debug_image/{}.jpg".format(checkpoint_dir, args.iter_num), di)
                model.train()

        # hard stop
        if args.iter_num == 50000:
            logger.info("ITER NUM == 50k, training successfully!")
            raise SystemExit
def main(args):
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of self-supervised training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of self-supervised training")
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

    if args.SiT_LinearEvaluation == 1:
        print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
        args.SiT_LinearEvaluation = 0

    utils.print_args(args)
    device = torch.device(args.device)

    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks,
                                  rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        collate_fn=collate_fn)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
        collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        representation_size=args.representation_size,
        drop_block_rate=None,
        training_mode=args.training_mode)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['rot_head.weight', 'rot_head.bias',
                  'contrastive_head.weight', 'contrastive_head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)
        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True
        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True
        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True
            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and args.mixup > 0.:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(model, criterion, data_loader_train,
                                    optimizer, device, epoch, loss_scaler,
                                    args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(model, criterion, data_loader_train,
                                         optimizer, device, epoch, loss_scaler,
                                         args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)

        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device, epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
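# The position-embedding interpolation in the finetune branch above, as a
# standalone sketch: split off the extra (class/task) tokens, reshape the patch
# tokens onto their 2D grid, resize bicubically, and flatten back. Shapes are
# illustrative (one extra token, a 14x14 grid resized to 16x16, dim 384):
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 1 + 14 * 14, 384)  # extra token + 14x14 patch grid
extra, patches = pos_embed[:, :1], pos_embed[:, 1:]
patches = patches.reshape(1, 14, 14, 384).permute(0, 3, 1, 2)
patches = F.interpolate(patches, size=(16, 16), mode='bicubic', align_corners=False)
patches = patches.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra, patches), dim=1)  # shape (1, 1 + 16*16, 384)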
def main(args):
    if args.prototype and prototype is None:
        raise ImportError(
            "The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if not args.prototype and args.weights:
        raise ValueError(
            "The weights parameter works only in prototype mode. Please pass the --prototype argument.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15,
            extensions=("avi", "mp4"),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if not args.prototype:
        transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171),
                                                               crop_size=(112, 112))
    else:
        if args.weights:
            weights = prototype.models.get_weight(args.weights)
            transform_test = weights.transforms()
        else:
            transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112),
                                                                resize_size=(128, 171))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15,
            extensions=("avi", "mp4"),
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    if not args.prototype:
        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
    else:
        model = prototype.models.video.__dict__[args.model](weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    iters_per_epoch = len(data_loader)
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma)
    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters)
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters)
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[warmup_iters])
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
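# The milestone conversion above shifts per-epoch milestones into per-iteration
# ones counted from the end of warmup, since SequentialLR restarts the main
# scheduler's step count at the switch point. Worked example (numbers are
# illustrative, not the script's defaults):
iters_per_epoch = 500
lr_milestones, lr_warmup_epochs = [20, 30], 10
print([iters_per_epoch * (m - lr_warmup_epochs) for m in lr_milestones])  # [5000, 10000]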
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda')

    # Data loading code
    print("Loading data")
    st = time.time()
    dataset = MSRAction3D(root=args.data_path,
                          frames_per_clip=args.clip_len,
                          step_between_clips=1,
                          num_points=args.num_points,
                          train=True)
    dataset_test = MSRAction3D(root=args.data_path,
                               frames_per_clip=args.clip_len,
                               step_between_clips=1,
                               num_points=args.num_points,
                               train=False)

    print("Creating data loaders")
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=args.workers,
                                              pin_memory=True)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    Model = getattr(Models, args.model)
    model = Model(radius=args.radius, nsamples=args.nsamples,
                  num_classes=dataset.num_classes)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = utils.WarmupMultiStepLR(optimizer, milestones=lr_milestones,
                                           gamma=args.lr_gamma,
                                           warmup_iters=warmup_iters,
                                           warmup_factor=1e-5)

    model_without_ddp = model
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    print("Start training")
    start_time = time.time()
    acc = 0
    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq)
        acc = max(acc, evaluate(model, criterion, data_loader_test, device=device))
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(checkpoint,
                                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(checkpoint,
                                 os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    print('Accuracy {}'.format(acc))
def main(args): if args.apex and amp is None: raise RuntimeError( "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "to enable mixed-precision training.") if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) device = torch.device(args.device) torch.backends.cudnn.benchmark = True train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') dataset, dataset_test, train_sampler, test_sampler = load_data( train_dir, val_dir, args) collate_fn = None num_classes = len(dataset.classes) mixup_transforms = [] if args.mixup_alpha > 0.0: mixup_transforms.append( transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) if args.cutmix_alpha > 0.0: mixup_transforms.append( transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) collate_fn = lambda batch: mixupcutmix(*default_collate(batch) ) # noqa: E731 data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True, collate_fn=collate_fn) data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True) print("Creating model") model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) opt_name = args.opt.lower() if opt_name.startswith("sgd"): optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov="nesterov" in opt_name) elif opt_name == 'rmsprop': optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) else: raise RuntimeError( "Invalid optimizer {}. Only SGD and RMSprop are supported.".format( args.opt)) if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) args.lr_scheduler = args.lr_scheduler.lower() if args.lr_scheduler == 'steplr': main_lr_scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) elif args.lr_scheduler == 'cosineannealinglr': main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs - args.lr_warmup_epochs) elif args.lr_scheduler == 'exponentiallr': main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( optimizer, gamma=args.lr_gamma) else: raise RuntimeError( "Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR " "are supported.".format(args.lr_scheduler)) if args.lr_warmup_epochs > 0: if args.lr_warmup_method == 'linear': warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs) elif args.lr_warmup_method == 'constant': warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs) else: raise RuntimeError( f"Invalid warmup lr method '{args.lr_warmup_method}'. 
Only linear and constant " "are supported.") lr_scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]) else: lr_scheduler = main_lr_scheduler model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module model_ema = None if args.model_ema: model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if model_ema: model_ema.load_state_dict(checkpoint['model_ema']) if args.test_only: evaluate(model, criterion, data_loader_test, device=device) return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema) lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) if model_ema: evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA') if args.output_dir: checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } if model_ema: checkpoint['model_ema'] = model_ema.state_dict() utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
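# utils.ExponentialMovingAverage above is a repo-local helper. A hedged sketch of
# the same idea on top of torch.optim.swa_utils.AveragedModel, with an
# illustrative decay; after every optimizer step the EMA copy is refreshed with
# ema_model.update_parameters(model).
import torch

def make_ema(model, decay=0.9999, device="cpu"):
    def ema_avg(avg_param, param, num_averaged):
        return decay * avg_param + (1.0 - decay) * param
    return torch.optim.swa_utils.AveragedModel(model, device=device, avg_fn=ema_avg)

ema_model = make_ema(torch.nn.Linear(8, 8))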
def load_data(traindir, valdir, args): # Data loading code print("Loading data") resize_size, crop_size = 256, 224 interpolation = InterpolationMode.BILINEAR if args.model == 'inception_v3': resize_size, crop_size = 342, 299 elif args.model.startswith('efficientnet_'): sizes = { 'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300), 'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600), } e_type = args.model.replace('efficientnet_', '') resize_size, crop_size = sizes[e_type] interpolation = InterpolationMode.BICUBIC print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) else: auto_augment_policy = getattr(args, "auto_augment", None) random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.ImageFolder( traindir, presets.ClassificationPresetTrain( crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob)) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) print("Took", time.time() - st) print("Loading validation data") cache_path = _get_cache_path(valdir) if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) else: dataset_test = torchvision.datasets.ImageFolder( valdir, presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation)) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) print("Creating data loaders") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( dataset) test_sampler = torch.utils.data.distributed.DistributedSampler( dataset_test) else: train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) return dataset, dataset_test, train_sampler, test_sampler
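# A compact, hedged restatement of the per-architecture transform selection in
# load_data above: eval resize/crop sizes and interpolation depend on the model
# family, with 256/224 bilinear as the default. Values mirror the table above.
from torchvision.transforms import InterpolationMode

EFFICIENTNET_SIZES = {
    'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
    'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
}

def eval_transform_params(model_name):
    if model_name == 'inception_v3':
        return 342, 299, InterpolationMode.BILINEAR
    if model_name.startswith('efficientnet_'):
        resize_size, crop_size = EFFICIENTNET_SIZES[model_name.replace('efficientnet_', '')]
        return resize_size, crop_size, InterpolationMode.BICUBIC
    return 256, 224, InterpolationMode.BILINEAR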
def main(args, layer_train_para, layer_names, layer_kernel_inc, pattern): if args.apex: if sys.version_info < (3, 0): raise RuntimeError( "Apex currently only supports Python 3. Aborting.") if amp is None: raise RuntimeError( "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "to enable mixed-precision training.") if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) device = torch.device(args.device) torch.backends.cudnn.benchmark = True train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') dataset, dataset_test, train_sampler, test_sampler = load_data( train_dir, val_dir, args.cache_dataset, args.distributed) data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True) print("Creating model") model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) # layer_train_para = [ # "layer1.0.conv1.weight", # "layer1.0.bn1.weight", # "layer1.0.bn1.bias", # "layer1.0.conv2.weight", # "layer1.0.bn2.weight", # "layer1.0.bn2.bias", # "layer1.1.conv1.weight", # "layer1.1.bn1.weight", # "layer1.1.bn1.bias", # "layer1.1.conv2.weight", # "layer1.1.bn2.weight", # "layer1.1.bn2.bias", # "layer2.0.conv2.weight", # "layer2.0.bn2.weight", # "layer2.0.bn2.bias", # "layer2.0.conv1.weight", # "layer2.0.bn1.weight", # "layer2.0.bn1.bias", # "layer2.0.downsample.0.weight", # "layer2.0.downsample.1.weight", # "layer2.0.downsample.1.bias"] # # layer_names = [ # "layer1.0.conv1", # "layer1.0.conv2", # "layer1.1.conv1", # "layer1.1.conv2", # "layer2.0.conv2", # "layer2.1.conv1", # "layer2.1.conv2" # ] # # layer_kernel_inc = [ # # "layer2.0.conv1", # # "layer2.0.downsample.0" # ] # # pattern = {} # pattern[0] = torch.tensor([[0, 0, 0], # [1, 1, 1], # [1, 1, 1]], dtype=torch.float32) # # pattern[1] = torch.tensor([[1, 1, 1], # [1, 1, 1], # [0, 0, 0]], dtype=torch.float32) # # pattern[2] = torch.tensor([[1, 1, 0], # [1, 1, 0], # [1, 1, 0]], dtype=torch.float32) # # pattern[3] = torch.tensor([[0, 1, 1], # [0, 1, 1], # [0, 1, 1]], dtype=torch.float32) layers = {} ki_layers = {} # for layer_name, layer in model.named_modules(): for layer_name, layer in model.named_modules(): if isinstance(layer, nn.Conv2d): # if is_same(layer.kernel_size) == 3 and layer.in_channels == 512: # if is_same(layer.kernel_size) == 3: if layer_name in layer_names: # layer_names.append(layer_name) layers[layer_name] = layer if layer_name in layer_kernel_inc: ki_layers[layer_name] = layer # print(layer_name) # if is_same(layer.kernel_size) == 3 and layer.in_channels==512: # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # mask = torch.tensor([[1, 1, 1], [1, 1, 0], [1, 0, 0]], dtype=torch.float32, device=device) # ztNAS_add_kernel_mask(model, layer, layer_name, mask=mask) #model = modify_model(model) # for name, param in model.named_parameters(): # names = [n + "." 
# for n in name.split(".")[:-1]] # if "".join(names)[:-1] not in layer_names: # param.requires_grad = False # else: # break for name, param in model.named_parameters(): if name in layer_train_para: param.requires_grad = True else: param.requires_grad = False # for name, param in model.named_parameters(): # print(name, param.requires_grad, param.data.shape) # print(model) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) admm_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=args.adam_epsilon) admm_re_train_optimizer = PruneAdam(model.named_parameters(), lr=args.lr, eps=args.adam_epsilon) if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: # for name, param in model.named_parameters(): # print(name) # print(param) layer_pattern = utils.get_layers_pattern(model, layer_names, pattern, device) utils.print_prune(model, layer_names, layer_pattern) for layer_name in layer_names: ztNAS_add_kernel_mask(model, layers[layer_name], layer_name, is_pattern=True, pattern=layer_pattern[layer_name].to(device)) # print(model) model.to(device) evaluate(model, criterion, data_loader_test, device=device) # evaluate(model, criterion, data_loader_test, device=device) return if args.retrain_only: epoch = 999 print("Start re-training") start_time = time.time() print("=" * 10, "Applying pruning model") layer_pattern = utils.get_layers_pattern(model, layer_names, pattern, device) # utils.print_prune(model, layer_names, layer_pattern) for layer_name in layer_names: ztNAS_add_kernel_mask(model, layers[layer_name], layer_name, is_pattern=True, pattern=layer_pattern[layer_name].to(device)) for layer_name in layer_kernel_inc: ztNAS_modify_kernel_shape(model, ki_layers[layer_name], layer_name, 2) # print(model) model.to(device) # evaluate(model, criterion, data_loader_test, device=device) print("=" * 10, "Retrain") re_train_one_epoch(model, criterion, admm_re_train_optimizer, data_loader, device, epoch, args.print_freq, layer_names, layer_pattern, data_loader_test, args.exploration, args.apex) acc1, acc5 = evaluate(model, criterion, data_loader_test, device=device, exploration=args.exploration) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) return acc1, acc5 print("Start training") start_time = time.time() Z, U = utils.initialize_Z_and_U(model, layer_names) rho = args.rho for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) Z, U = train_one_epoch(model, criterion, admm_optimizer, data_loader, device, epoch, args.print_freq, layer_names, percent, pattern, Z, U, rho, args.apex) rho = rho * 10 lr_scheduler.step() if args.output_dir: 
checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) evaluate(model, criterion, data_loader_test, device=device) print("=" * 10, "Applying pruning model") layer_pattern = utils.get_layers_pattern(model, layer_names, pattern, device) # utils.print_prune(model, layer_names, layer_pattern) for layer_name in layer_names: ztNAS_add_kernel_mask(model, layers[layer_name], layer_name, is_pattern=True, pattern=layer_pattern[layer_name].to(device)) # print(model) model.to(device) # evaluate(model, criterion, data_loader_test, device=device) print("=" * 10, "Retrain") re_train_one_epoch(model, criterion, admm_re_train_optimizer, data_loader, device, epoch, args.print_freq, layer_names, layer_pattern, data_loader_test, args.exploration, args.apex) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch + 1, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time)))
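# ztNAS_add_kernel_mask above is project-specific. A hedged sketch of its core
# idea, pattern-based kernel pruning: a forward pre-hook zeroes fixed positions
# of every 3x3 kernel according to a binary pattern before each forward pass.
import torch
import torch.nn as nn

def apply_kernel_pattern(conv, pattern):
    # pattern: (3, 3) binary mask, broadcast over (out_ch, in_ch, 3, 3)
    mask = pattern.to(conv.weight.device).expand_as(conv.weight)
    def hook(module, inputs):
        module.weight.data.mul_(mask)
    conv.register_forward_pre_hook(hook)

conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
apply_kernel_pattern(conv, torch.tensor([[0., 0., 0.],
                                         [1., 1., 1.],
                                         [1., 1., 1.]]))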
def main(args): utils.init_distributed_mode(args) update_config_from_file(args.cfg) print(args) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) sampler_train = torch.utils.data.RandomSampler(dataset_train) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=int( 2 * args.batch_size), sampler=sampler_val, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating SuperVisionTransformer") print(cfg) model = Vision_TransformerSuper( img_size=args.input_size, patch_size=args.patch_size, embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH, num_heads=cfg.SUPERNET.NUM_HEADS, mlp_ratio=cfg.SUPERNET.MLP_RATIO, qkv_bias=True, drop_rate=args.drop, drop_path_rate=args.drop_path, gp=args.gp, num_classes=args.nb_classes, max_relative_position=args.max_relative_position, relative_position=args.relative_position, change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos) choices = { 'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO, 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM, 'depth': cfg.SEARCH_SPACE.DEPTH } model.to(device) if args.teacher_model: teacher_model = create_model( args.teacher_model, pretrained=True, num_classes=args.nb_classes, ) teacher_model.to(device) teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: teacher_model = None teacher_loss = None model_ema = None model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) # criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True) # save config for later experiments with open(output_dir / "config.yaml", 'w') as f: f.write(args_text) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) retrain_config = None if args.mode == 'retrain' and "RETRAIN" in cfg: retrain_config = { 'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH, 'num_heads': cfg.RETRAIN.NUM_HEADS, 'mlp_ratio': cfg.RETRAIN.MLP_RATIO } if args.eval: print(retrain_config) test_stats = evaluate(data_loader_val, model, device, mode=args.mode, retrain_config=retrain_config) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print("Start training") start_time = time.time() 
max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, amp=args.amp, teacher_model=teacher_model, teach_loss=teacher_loss, choices=choices, mode=args.mode, retrain_config=retrain_config, ) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, # 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device, amp=args.amp, choices=choices, mode=args.mode, retrain_config=retrain_config) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
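# The linear LR scaling rule used above, in isolation: the base learning rate is
# multiplied by the global batch size relative to a reference batch of 512.
# Numbers below are illustrative.
base_lr, batch_size, world_size = 5e-4, 128, 8
linear_scaled_lr = base_lr * batch_size * world_size / 512.0   # -> 1e-3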
def main(args): if args.apex: if sys.version_info < (3, 0): raise RuntimeError( "Apex currently only supports Python 3. Aborting.") if amp is None: raise RuntimeError( "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "to enable mixed-precision training.") if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) device = torch.device(args.device) torch.backends.cudnn.benchmark = True train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') dataset, dataset_test, train_sampler, test_sampler = load_data( train_dir, val_dir, args.cache_dataset, args.distributed) data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True) print("Creating model") model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device) return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex) lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
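# The script above relies on NVIDIA apex for mixed precision. A hedged sketch of
# the equivalent training step with PyTorch's native AMP, which needs no external
# dependency; all names here are illustrative.
import torch

def amp_train_step(model, criterion, optimizer, images, targets, scaler):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# scaler = torch.cuda.amp.GradScaler()  # created once and checkpointed with the model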
def main(args): if args.is_hmp: from hmp import hmp hmp.convert(opt_level=args.hmp_opt_level, bf16_file_path=args.hmp_bf16, fp32_file_path=args.hmp_fp32, isVerbose=args.hmp_verbose) if args.apex: if sys.version_info < (3, 0): raise RuntimeError( "Apex currently only supports Python 3. Aborting.") if amp is None: raise RuntimeError( "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "to enable mixed-precision training.") if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) if args.device == 'habana': sys.path.append( os.path.realpath( os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common"))) from library_loader import load_habana_module load_habana_module() torch.manual_seed(args.seed) if args.deterministic: seed = args.seed if args.device == 'cuda': torch.cuda.manual_seed(seed) else: seed = None device = torch.device(args.device) torch.backends.cudnn.benchmark = True # Limit the test(eval) phase batch size to a lower value to reduce overall device memory pressure test_batch_size = args.batch_size if args.batch_size > 32: test_batch_size = 32 train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') dataset, dataset_test, train_sampler, test_sampler = load_data( train_dir, val_dir, args.cache_dataset, args.distributed) data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, worker_init_fn=dl_worker_init_fn(seed), pin_memory=True, drop_last=True) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=test_batch_size, sampler=test_sampler, num_workers=args.workers, worker_init_fn=dl_worker_init_fn(seed), pin_memory=True, drop_last=True) print("Creating model") #model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) #Instead of importing resnet model from the standard torchvision package, #import from a local copy. A local copy of resnet model file is used so that #modifications can be done to the resnet model if necessary. model = resnet_models.__dict__[args.model](pretrained=args.pretrained) model.to(device) if args.channels_last: if (device == torch.device('cuda')): print('Converting model to channels_last format on CUDA') model.to(memory_format=torch.channels_last) elif (args.device == 'habana'): print('Converting model params to channels_last format on Habana') #TODO: #model.to(device).to(memory_format=torch.channels_last) #The above model conversion doesn't change the model params #to channels_last for many components - e.g. convolution. #So we are forced to rearrange such tensors ourselves. 
if (args.device == 'habana'): permute_params(model, True) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) model_for_eval = model if args.run_trace_mode: sample_trace_tensor = enable_tracing(device) if args.channels_last: sample_trace_tensor = sample_trace_tensor.contiguous( memory_format=torch.channels_last) # Create traced model for eval model.eval() model_for_eval = torch.jit.trace(model, sample_trace_tensor, check_trace=False) # Create traced model for train model.train() model = torch.jit.trace(model, sample_trace_tensor, check_trace=False) model_for_train = model # TBD: pass the right module for ddp model_without_ddp = model if args.distributed: if args.device == 'habana': model = torch.nn.parallel.DistributedDataParallel( model, bucket_cap_mb=100, broadcast_buffers=False) else: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module model_for_train = model if args.resume: if (args.device == 'habana'): permute_params(model_without_ddp, False) checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) if (args.device == 'habana'): permute_momentum(optimizer, True) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if (args.device == 'habana'): permute_params(model_without_ddp, True) if args.test_only: evaluate(model_for_eval, criterion, data_loader_test, device=device, print_freq=args.print_freq) return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch(model_for_train, criterion, optimizer, data_loader, device, epoch, print_freq=args.print_freq, apex=args.apex) lr_scheduler.step() evaluate(model_for_eval, criterion, data_loader_test, device=device, print_freq=args.print_freq) if (args.output_dir and args.save_checkpoint): if args.device == 'habana': permute_params(model_without_ddp, False) #Use this model only to copy the state_dict of the actual model copy_model = resnet_models.__dict__[args.model]( pretrained=args.pretrained) copy_model.load_state_dict(model_without_ddp.state_dict()) permute_momentum(optimizer, False) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to('cpu') checkpoint = { 'model': copy_model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to('habana') permute_params(model_without_ddp, True) else: checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 
'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
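# The Habana branch above notes that Module.to(memory_format=...) does not reach
# every parameter on that backend, hence the manual permutes. On CUDA/CPU the
# standard channels_last recipe is the short form below (hedged sketch; inputs
# must be converted as well so convolutions can pick NHWC kernels).
import torch
import torchvision

model = torchvision.models.resnet18().to(memory_format=torch.channels_last)
x = torch.randn(2, 3, 224, 224).contiguous(memory_format=torch.channels_last)
out = model(x)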
def main(args): args.log_dir = save_path_formatter(args) if args.deconv: args.deconv = partial(FastDeconv, bias=args.bias, eps=args.eps, n_iter=args.deconv_iter, block=args.block, sampling_stride=args.stride) if args.tensorboard: from torch.utils.tensorboard import SummaryWriter args.writer = SummaryWriter(args.log_dir, flush_secs=30) if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) device = torch.device(args.device) transform = get_transform(mode='train', base_size=args.base_size) dataset, num_classes = get_dataset(args.dataset, "train", transform=transform) transform = get_transform(mode='test', base_size=args.base_size) dataset_test, _ = get_dataset(args.dataset, "val", transform=transform) if args.dataset == 'cityscapes': args.colormap = np.asarray([ dataset.classes[i].color for i in range(max(dataset.new_classes) + 1) ]) else: args.colormap = create_mapillary_vistas_label_colormap() if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( dataset) test_sampler = torch.utils.data.distributed.DistributedSampler( dataset_test) else: train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, collate_fn=utils.collate_fn, drop_last=True) data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn) #model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,aux_loss=args.aux_loss,pretrained=args.pretrained) model = models.segmentation.__dict__[args.model]( num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained, deconv=args.deconv, pretrained_backbone=args.pretrained_backbone) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model.load_state_dict(checkpoint['model']) args.start_epoch = checkpoint['epoch'] print("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) del checkpoint model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module elif args.device == 'cuda': model = torch.nn.DataParallel(model).cuda() if args.test_only: confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) print(confmat) return params_to_optimize = [ { "params": [ p for p in model_without_ddp.backbone.parameters() if p.requires_grad ] }, { "params": [ p for p in model_without_ddp.classifier.parameters() if p.requires_grad ] }, ] if args.aux_loss: params = [ p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad ] params_to_optimize.append({"params": params, "lr": args.lr * 10}) optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs))**0.9) if args.resume: total_steps = len(data_loader) * args.start_epoch global n_iter for i in range(total_steps): n_iter = n_iter + 1 lr_scheduler.step() start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch(model, criterion, 
optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) if epoch == 0 or (epoch + 1) % args.eval_freq == 0: confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch, #'args': args }, #os.path.join(args.log_dir, 'model_{}.pth'.format(epoch))) os.path.join(args.log_dir, 'model.pth')) print(confmat) acc_global, acc, iu = confmat.compute() acc_global = acc_global.item() * 100 iu = iu.mean().item() * 100 if args.tensorboard: args.writer.add_scalar('Acc/Test', acc_global, epoch + 1) args.writer.add_scalar('IOU/Test', iu, epoch + 1) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if args.tensorboard: args.writer.close()  # args.writer only exists when tensorboard logging was enabled
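# The segmentation script above uses the "poly" decay via LambdaLR, stepped once
# per iteration. A hedged, standalone restatement (total_iters is illustrative,
# standing in for len(data_loader) * args.epochs):
import torch

model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
total_iters = 1000
poly = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda it: (1 - it / total_iters) ** 0.9)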
def main(args): utils.init_distributed_mode(args) print(args) device = torch.device(args.device) # Data loading code print("Loading data") dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path) dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path) print("Creating data loaders") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) else: train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) if args.aspect_ratio_group_factor >= 0: group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) else: train_batch_sampler = torch.utils.data.BatchSampler( train_sampler, args.batch_size, drop_last=True) data_loader = torch.utils.data.DataLoader( dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn) print("Creating model") kwargs = {} if "keypoint" in args.model: kwargs["num_keypoints"] = 6 # if "rcnn" in args.model: # kwargs["rpn_score_thresh"] = 0.0 model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained, **kwargs) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model_without_ddp = model.module params = [p for p in model.parameters() if p.requires_grad] optimizer = torch.optim.SGD( params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: evaluate(model, data_loader_test, device=device) return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) lr_scheduler.step() if args.output_dir: utils.save_on_master({ 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'args': args, 'epoch': epoch}, os.path.join(args.output_dir, 'model77.pth')) # evaluate after every epoch evaluate(model, data_loader_test, device=device) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
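# utils.collate_fn in the detection scripts is repo-local; for variable-size
# images paired with per-image target dicts, the conventional implementation is
# just a transpose of the batch. Hedged sketch:
def detection_collate_fn(batch):
    # [(img1, tgt1), (img2, tgt2), ...] -> ((img1, img2, ...), (tgt1, tgt2, ...))
    return tuple(zip(*batch))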
def main(args): if args.apex: if sys.version_info < (3, 0): raise RuntimeError( "Apex currently only supports Python 3. Aborting.") if amp is None: raise RuntimeError( "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "to enable mixed-precision training.") if args.output_dir: utils.mkdir(args.output_dir) vis = utils.Visualize(args) utils.init_distributed_mode(args) print(args) print("torch version: ", torch.__version__) print("torchvision version: ", torchvision.__version__) device = torch.device(args.device) torch.backends.cudnn.benchmark = True # Data loading code print("Loading data") traindir = os.path.join( args.data_path, 'train_256' if not args.fast_test else 'val_256_bob') valdir = os.path.join(args.data_path, 'val_256_bob') normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]) print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) frame_transform_train = utils.make_frame_transform(args.frame_transforms) transform_train = torchvision.transforms.Compose([ # torchvision.transforms.RandomGrayscale(p=1), frame_transform_train, T.ToFloatTensorInZeroOne(), T.Resize((256, 256)), # T.Resize((128, 171)), # T.RandomHorizontalFlip(), # T.GaussianBlurTransform(), normalize, # T.RandomCrop((112, 112)) ]) def make_dataset(is_train): _transform = transform_train if is_train else transform_test if 'kinetics' in args.data_path.lower(): return Kinetics400(traindir if is_train else valdir, frames_per_clip=args.clip_len, step_between_clips=1, transform=transform_train, extensions=('mp4'), frame_rate=args.frame_skip) else: return VideoList( args, is_train, frame_gap=args.frame_skip, transform=_transform, # frame_transform=_frame_transform ) if args.cache_dataset and os.path.exists(cache_path): print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) dataset.transform = transform_train else: if args.distributed: print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") dataset = make_dataset(is_train=True) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) if hasattr(dataset, 'video_clips'): dataset.video_clips.compute_clips(args.clip_len, 1, frame_rate=15) print("Took", time.time() - st) print("Loading validation data") cache_path = _get_cache_path(valdir) transform_test = torchvision.transforms.Compose([ T.ToFloatTensorInZeroOne(), # T.Resize((128, 171)), # normalize, # T.CenterCrop((112, 112)) T.Resize((256, 256)), normalize ]) if args.cache_dataset and os.path.exists(cache_path): print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) dataset_test.transform = transform_test else: if args.distributed: print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") # dataset_test = Kinetics400( # valdir, # frames_per_clip=args.clip_len, # step_between_clips=1, # transform=transform_test, # extensions=('mp4') # ) dataset_test = make_dataset(is_train=False) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) if hasattr(dataset, 'video_clips'): dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15) def make_data_sampler(is_train, dataset): if hasattr(dataset, 'video_clips'): 
_sampler = RandomClipSampler if is_train else UniformClipSampler return _sampler(dataset.video_clips, args.clips_per_video) else: return torch.utils.data.sampler.RandomSampler( dataset) if is_train else None print("Creating data loaders") train_sampler, test_sampler = make_data_sampler(True, dataset), \ make_data_sampler(False, dataset_test) # train_sampler = train_sampler(dataset.video_clips, args.clips_per_video) # test_sampler = test_sampler(dataset_test.video_clips, args.clips_per_video) if args.distributed: train_sampler = DistributedSampler(train_sampler) test_sampler = DistributedSampler(test_sampler) data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True, collate_fn=collate_fn) data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True, collate_fn=collate_fn) print("Creating model") import resnet import timecycle as tc # model = resnet.__dict__[args.model](pretrained=args.pretrained) model = tc.TimeCycle(args) # utils.compute_RF_numerical(model.resnet, torch.ones(1, 3, 1, 112, 112).numpy()) # import pdb; pdb.set_trace() # print(utils.compute_RF_numerical(model,img_np)) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss() lr = args.lr * args.world_size # optimizer = torch.optim.SGD( # model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) # convert scheduler to be per iteration, not per epoch, for warmup that lasts # between different epochs warmup_iters = args.lr_warmup_epochs * len(data_loader) lr_milestones = [len(data_loader) * m for m in args.lr_milestones] lr_scheduler = WarmupMultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma, warmup_iters=warmup_iters, warmup_factor=1e-5) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module if args.data_parallel: model = torch.nn.parallel.DataParallel(model) model_without_ddp = model.module if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device) return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex, vis=vis) # evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time 
{}'.format(total_time_str))
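# A hedged sketch of the clip-sampler wiring used above for datasets that expose
# a .video_clips index (as in torchvision's video references); clips_per_video
# is illustrative.
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler

def make_clip_samplers(train_set, test_set, clips_per_video=5):
    train_sampler = RandomClipSampler(train_set.video_clips, clips_per_video)
    test_sampler = UniformClipSampler(test_set.video_clips, clips_per_video)
    return train_sampler, test_sampler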
def main(args): if args.apex and amp is None: raise RuntimeError( "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "to enable mixed-precision training.") if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) device = torch.device(args.device) torch.backends.cudnn.benchmark = True train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') dataset, dataset_test, train_sampler, test_sampler = load_data( train_dir, val_dir, args.cache_dataset, args.distributed) dataset.samples = [dataset.samples[idx] for idx in range(1024)] dataset.targets = [dataset.targets[idx] for idx in range(1024)] dataset_test.samples = [dataset.samples[idx] for idx in range(1024)] dataset_test.targets = [dataset.targets[idx] for idx in range(1024)] data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True) print("Creating model") model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) model.to(device) # Mehrdad: Fuse from DG_Prune.FuseHook import Fuse_Hook, get_modules_to_fuse # from torch.quantization.fuse_modules import fuse_modules # model_fused = fuse_modules(model, h.modules_to_fuse, inplace=False, fuser_func=modified_fuse_known_modules ) # Mehrdad: Prune from DG_Prune import DG_Pruner, TaylorImportance, MagnitudeImportance, RigLImportance, PrunableConv2d dgPruner = None if args.prune: dgPruner = DG_Pruner() model = dgPruner.swap_prunable_modules(model) # dgPruner.dump_sparsity_stat(model, output_dir, 0) pruners = dgPruner.pruners_from_file('DG_Prune/rigl_resnet50.json') hooks = dgPruner.add_custom_pruning(model, RigLImportance) # fuse_type_list = [[PrunableConv2d, nn.BatchNorm2d]] for image, _ in data_loader_test: sample_image = image[0].unsqueeze(0) break modules_to_fuse = get_modules_to_fuse(model, fuse_type_list, sample_image) dgPruner.attach_bn_to_prunables(model, modules_to_fuse) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * ( 1 - args.lrf) + args.lrf # cosine lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device, print_freq=args.print_freq, dgPruner=dgPruner, output_dir=args.output_dir) return print("Start training") start_time = time.time() for lth_stage in range(0, dgPruner.num_stages() + 1): if (lth_stage != 0): checkpoint = 
dgPruner.rewind_masked_checkpoint() model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 dgPruner.dump_sparsity_stat(model, args.output_dir, lth_stage * 10000) for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_metrics = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, dgPruner=dgPruner, output_dir=args.output_dir) lr_scheduler.step() eval_metrics = evaluate(model, criterion, data_loader_test, device=device, print_freq=args.print_freq, dgPruner=dgPruner, output_dir=args.output_dir) if args.output_dir: checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'checkpoint.pth')) # Mehrdad: LTH, pruning in the end if (args.prune): if (epoch == args.epochs - 1): dgPruner.prune_n_reset(epoch) dgPruner.dump_sparsity_stat(model, args.output_dir, epoch) dgPruner.apply_mask_to_weight() checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args } # Save checkpoints if (lth_stage == 0) and (epoch == dgPruner.rewind_epoch( args.epochs)): dgPruner.save_rewind_checkpoint(checkpoint) if (epoch == args.epochs - 1): dgPruner.save_final_checkpoint(checkpoint) update_summary(epoch, train_metrics, eval_metrics, os.path.join(args.output_dir, 'summary.csv')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
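# The cosine schedule defined earlier in this script anneals the LR multiplier
# from 1.0 down to args.lrf over args.epochs. Hedged restatement with
# illustrative constants, stepped once per epoch:
import math
import torch

epochs, lrf = 90, 0.01
lf = lambda e: ((1 + math.cos(e * math.pi / epochs)) / 2) * (1 - lrf) + lrf
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)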
def main(args):
    utils.init_distributed_mode(args)
    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.teacher_path, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type,
        args.distillation_alpha, args.distillation_tau)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / ('checkpoint_%04d.pth' % epoch)]
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': get_state_dict(model_ema),
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        if not args.train_without_eval:
            test_stats = evaluate(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        else:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
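The position-embedding interpolation in the finetuning branch above is the densest step, so here is a minimal, self-contained sketch of the same reshape-interpolate-flatten round trip on dummy tensors. The concrete sizes (196 -> 576 patches, 768-dim embedding, one extra token) are illustrative assumptions for a ViT-B/16 model moving from 224px to 384px input, not values taken from the script:

import torch
import torch.nn.functional as F

embedding_size = 768
num_extra_tokens = 1                      # class token (DeiT has 2: class + dist)
pos_embed_checkpoint = torch.randn(1, 196 + num_extra_tokens, embedding_size)
new_size = 24                             # target grid side: int(576 ** 0.5)

orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)  # 14

extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]   # kept unchanged
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]     # (1, 196, 768)

# Treat the 196 position vectors as a 14x14 feature map with 768 channels ...
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
# ... resize it spatially with bicubic interpolation ...
pos_tokens = F.interpolate(pos_tokens, size=(new_size, new_size),
                           mode='bicubic', align_corners=False)
# ... then flatten back into a (1, 576, 768) sequence of position embeddings
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)

new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
assert new_pos_embed.shape == (1, 576 + num_extra_tokens, embedding_size)

The key idea is that patch position embeddings form a 2D grid, so resizing the input resolution is just an image-style resize of that grid while the class/distillation tokens pass through untouched.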
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed "
                           "on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')

    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    model = torchvision.models.quantization.__dict__[args.model](
        pretrained=True, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
        torch.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(
            dataset,
            indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        model.eval()
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
        torch.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.quantization.convert(model, inplace=True)
        if args.output_dir:
            print('Saving quantized model')
            if utils.is_main_process():
                torch.save(model.state_dict(),
                           os.path.join(args.output_dir, 'quantized_post_train_model.pth'))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.quantization.enable_observer)
    model.apply(torch.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print('Starting training for epoch', epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        with torch.no_grad():
            if epoch >= args.num_observer_update_epochs:
                print('Disabling observer for subseq epochs, epoch = ', epoch)
                model.apply(torch.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print('Freezing BN for subseq epochs, epoch = ', epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print('Evaluate QAT model')
            evaluate(model, criterion, data_loader_test, device=device)
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device('cpu'))
            torch.quantization.convert(quantized_eval_model, inplace=True)

            print('Evaluate Quantized model')
            evaluate(quantized_eval_model, criterion, data_loader_test,
                     device=torch.device('cpu'))

        model.train()

        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'eval_model': quantized_eval_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))
            print('Saving models after epoch ', epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
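The ordering in the QAT loop above matters: fake quantization is inserted before training, observers and batch-norm statistics are frozen after a few epochs so quantization ranges stop drifting, and only a CPU copy is converted to a true int8 model for evaluation. A minimal sketch of that lifecycle on a toy model follows; the toy nn.Sequential, random data, and epoch thresholds are illustrative assumptions, and a runnable end-to-end int8 model would additionally need QuantStub/DeQuantStub wrappers (torchvision's quantizable models include them internally):

import copy
import torch
import torch.nn as nn

# Toy fp32 model standing in for the quantizable torchvision model
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# 1) Attach a QAT qconfig and insert observers + fake-quant modules
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

for epoch in range(4):
    # 2) Train with fake quantization in the graph
    model.train()
    x = torch.randn(2, 3, 16, 16)
    loss = criterion(model(x), torch.randn(2, 8, 14, 14))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 3) Past some epoch, stop updating quantization ranges and BN statistics
    #    (freeze_bn_stats only affects fused ConvBn blocks; this toy model is
    #    unfused for brevity, so it is a no-op here)
    if epoch >= 2:
        model.apply(torch.quantization.disable_observer)
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # 4) Convert a CPU copy to an int8 model for evaluation; the QAT model
    #    itself keeps training in fp32 with fake quantization
    quantized_eval_model = copy.deepcopy(model).eval().cpu()
    torch.quantization.convert(quantized_eval_model, inplace=True)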
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args.print_freq)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
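Several of these scripts share the same resume convention: save a dict of state_dicts plus the epoch number, load it back with map_location='cpu', and restart at checkpoint['epoch'] + 1 so the saved epoch is not repeated. A minimal sketch of that round trip, using a toy nn.Linear as an illustrative stand-in for the real model:

import os
import torch
import torch.nn as nn

# Illustrative stand-ins for the script's model/optimizer/scheduler
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Save: the same dict layout the scripts write via utils.save_on_master
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'epoch': 5,
}
torch.save(checkpoint, 'checkpoint.pth')

# Resume: load to CPU first so a GPU-trained checkpoint restores on any device,
# then resume at the epoch *after* the one that was saved
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
start_epoch = checkpoint['epoch'] + 1
os.remove('checkpoint.pth')  # clean up the demo file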
def main(args):
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    # Support loading a custom Pascal-VOC-format dataset: set --dataset to custom_voc
    if args.dataset == 'custom_voc':
        # dataset, num_classes = get_custom_voc(args.train_data_path, get_transform(train=True))
        # dataset_test, _ = get_custom_voc(args.test_data_path, get_transform(train=False))

        # A custom Pascal-VOC dataset takes no image_set argument, so pass None here
        dataset, num_classes = get_dataset(args.dataset, None,
                                           get_transform(train=True),
                                           args.train_data_path)
        dataset_test, _ = get_dataset(args.dataset, None,
                                      get_transform(train=False),
                                      args.test_data_path)
    else:
        dataset, num_classes = get_dataset(
            args.dataset, "train" if args.dataset == 'coco' else 'trainval',
            get_transform(train=True), args.data_path)
        dataset_test, _ = get_dataset(
            args.dataset, "test" if args.dataset == 'coco' else 'val',
            get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler,
        num_workers=args.workers, collate_fn=utils.collate_fn)

    print("Creating model")
    # model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes, pretrained=args.pretrained)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # To resume training, the optimizer and LR schedule are needed in addition to the model
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    # For test-only runs, note that the checkpoint is also passed via --resume. The original
    # author only mentions --resume for resuming training, but per the official docs it can
    # be used for inference as well:
    # https://pytorch.org/tutorials/beginner/saving_loading_models.html
    if args.test_only:
        if not args.resume:
            raise Exception('A checkpoint model is required for inference!')
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint['model'])
            if 'coco' == args.dataset:
                coco_evaluate(model_without_ddp, data_loader_test, device=device)
            elif 'voc' == args.dataset:
                voc_evaluate(model_without_ddp, data_loader_test, device=device)
            elif 'custom_voc' == args.dataset:
                custom_voc_evaluate(model_without_ddp, data_loader_test, device=device)
            else:
                print(f'No evaluation method available for the dataset {args.dataset}')
        # evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            # model.save('./checkpoints/model_{}_{}.pth'.format(args.dataset, epoch))
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),  # save the network parameters only
                # 'model': model_without_ddp,  # save the whole network instead
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args
            }, os.path.join(args.output_dir,
                            'model_{}_{}.pth'.format(args.dataset, epoch)))

        # evaluate after every epoch
        if args.dataset == 'coco':
            coco_evaluate(model, data_loader_test, device=device)
        elif 'voc' == args.dataset:
            voc_evaluate(model, data_loader_test, device=device)
        elif 'custom_voc' == args.dataset:
            custom_voc_evaluate(model, data_loader_test, device=device)
        else:
            print(f'No evaluation method available for the dataset {args.dataset}')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
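Both detection loaders above pass collate_fn=utils.collate_fn because detection batches cannot be stacked into one tensor: images differ in size and each target holds a variable number of boxes. In torchvision's reference code this helper is, to the best of my knowledge, just tuple(zip(*batch)); a hedged, self-contained sketch of the same idea (the sample images, box coordinates, and labels below are made up for illustration):

import torch

def collate_fn(batch):
    # Keep images and targets as parallel tuples instead of stacking,
    # since sizes and box counts vary per sample
    return tuple(zip(*batch))

# Two (image, target) samples with different image sizes and box counts
batch = [
    (torch.randn(3, 300, 400), {'boxes': torch.tensor([[0., 0., 10., 10.]]),
                                'labels': torch.tensor([1])}),
    (torch.randn(3, 500, 333), {'boxes': torch.tensor([[5., 5., 50., 60.],
                                                       [1., 2., 30., 40.]]),
                                'labels': torch.tensor([2, 3])}),
]
images, targets = collate_fn(batch)
assert len(images) == len(targets) == 2   # tuples, not stacked tensors

The detection models then accept the list of differently-sized images directly and handle resizing/padding internally, which is why no Resize transform appears in the loader.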