def main(args):
    """Train (or, with ``args.eval``, only evaluate) an image-classification model.

    Pipeline: init distributed mode -> seed RNGs -> build datasets/samplers/
    loaders -> optional Mixup/CutMix -> create model (optionally initialised
    from a finetune checkpoint, interpolating its position embedding) ->
    optional EMA and DDP wrapping -> optimizer/scheduler/loss (optionally
    wrapped in ``DistillationLoss``) -> optional resume -> epoch loop with
    per-epoch checkpointing, optional validation, and JSON-line logging.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; mutated in place (``nb_classes``,
        ``lr``, ``start_epoch``).
    """
    utils.init_distributed_mode(args)
    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # Fix the seed for reproducibility; the rank offset decorrelates workers.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # NOTE(review): the distributed sampler path is hard-coded on; the
    # single-process branch below is currently unreachable.
    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank,
                shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank,
                shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Validation has no backward pass, so a larger batch fits in memory.
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False)

    mixup_fn = None
    mixup_active = (args.mixup > 0 or args.cutmix > 0.
                    or args.cutmix_minmax is not None)
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop classifier heads whose shape no longer matches (e.g. a
        # different number of classes); they are re-initialised by the model.
        for k in [
                'head.weight', 'head.bias',
                'head_dist.weight', 'head_dist.bias'
        ]:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Interpolate the position embedding to the new patch grid size.
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(
            -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic',
            align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        # strict=False: removed head keys above must not abort loading.
        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP
        # but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: scale by global batch size relative to 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # FIX: removed a dead `criterion = LabelSmoothingCrossEntropy()` that was
    # unconditionally overwritten by the branches below.
    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.teacher_path, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(criterion, teacher_model,
                                 args.distillation_type,
                                 args.distillation_alpha,
                                 args.distillation_tau)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Only restore optimizer state when actually continuing training.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle per epoch so each process sees a new data order.
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch,
            loss_scaler, args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / ('checkpoint_%04d.pth' % (epoch))]
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        if not args.train_without_eval:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }
        else:
            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    """Train (or evaluate) a SiT model in SSL or finetune mode.

    In SSL mode all heavy augmentations (smoothing, Mixup, CutMix,
    AutoAugment, random-erasing knobs) are zeroed out first. Builds
    datasets/loaders with distributed samplers, creates the model
    (optionally loading a finetune checkpoint with position-embedding
    interpolation), optionally freezes the backbone for linear evaluation,
    then runs the epoch loop, validating and checkpointing every
    ``args.validate_every`` epochs.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; mutated in place (augmentation
        flags, ``nb_classes``, ``lr``, ``start_epoch``).
    """
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of Self-supervise training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of Self-supervise training")
        # FIX: the original chain assigned `args.reprob` twice in a row
        # (a no-op duplicate); the redundant target was removed.
        # NOTE(review): verify no other erasing flag (e.g. remode/recount
        # sibling) was meant to be zeroed here.
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

        if args.SiT_LinearEvaluation == 1:
            print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
            args.SiT_LinearEvaluation = 0

    utils.print_args(args)

    device = torch.device(args.device)

    # Fix the seed for reproducibility; rank offset decorrelates workers.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks,
                                  rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank,
            shuffle=True)

    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        collate_fn=collate_fn)

    # Larger validation batch: no backward pass during evaluation.
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
        collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = (args.mixup > 0 or args.cutmix > 0.
                    or args.cutmix_minmax is not None)
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        representation_size=args.representation_size,
        drop_block_rate=None,
        training_mode=args.training_mode)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop SSL heads whose shapes no longer match the current model.
        for k in ['rot_head.weight', 'rot_head.bias',
                  'contrastive_head.weight', 'contrastive_head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Interpolate the position embedding to the new patch grid size.
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        # Extra (class-like) tokens are kept unchanged; only position
        # tokens are interpolated.
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(
            -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic',
            align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation: only the task heads
    # (and, if present, their pre-logits projections) stay trainable.
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)

        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True

        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True

        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True

            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule relative to a global batch size of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        # Multi-task loss combining the SSL objectives.
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and args.mixup > 0.:
        # Smoothing is handled by the mixup label transform.
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Only restore training state when actually continuing training.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(
                model, criterion, data_loader_train, optimizer, device,
                epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(
                model, criterion, data_loader_train, optimizer, device,
                epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)

        # Validate (and checkpoint) only every `validate_every` epochs.
        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device,
                                          epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    """Train (or evaluate) a supernet Vision Transformer driven by a YAML config.

    Reads the supernet architecture from the global ``cfg`` (populated via
    ``update_config_from_file``), builds datasets/loaders, an optional
    pretrained teacher for distillation, optimizer/scheduler/loss, supports
    resume, and runs the epoch loop sampling from the search-space
    ``choices`` (or a fixed ``retrain_config`` in retrain mode).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; mutated in place (``nb_classes``,
        ``lr``, ``start_epoch``).
    """
    utils.init_distributed_mode(args)
    # Merge the YAML config file into the global `cfg` object.
    update_config_from_file(args.cfg)

    print(args)
    # Serialised args, written to output_dir/config.yaml further below.
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility (rank offset decorrelates workers)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank,
                shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank,
                shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Evaluation uses a doubled batch size (no backward pass).
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=int(2 * args.batch_size),
        sampler=sampler_val, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False)

    mixup_fn = None
    mixup_active = (args.mixup > 0 or args.cutmix > 0.
                    or args.cutmix_minmax is not None)
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating SuperVisionTransformer")
    print(cfg)
    # Supernet dimensions come from the SUPERNET section of the config.
    model = Vision_TransformerSuper(
        img_size=args.input_size,
        patch_size=args.patch_size,
        embed_dim=cfg.SUPERNET.EMBED_DIM,
        depth=cfg.SUPERNET.DEPTH,
        num_heads=cfg.SUPERNET.NUM_HEADS,
        mlp_ratio=cfg.SUPERNET.MLP_RATIO,
        qkv_bias=True,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        gp=args.gp,
        num_classes=args.nb_classes,
        max_relative_position=args.max_relative_position,
        relative_position=args.relative_position,
        change_qkv=args.change_qkv,
        abs_pos=not args.no_abs_pos)

    # Candidate sub-network dimensions sampled during supernet training.
    choices = {
        'num_heads': cfg.SEARCH_SPACE.NUM_HEADS,
        'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
        'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM,
        'depth': cfg.SEARCH_SPACE.DEPTH
    }

    model.to(device)

    # Optional pretrained teacher for distillation during supernet training.
    if args.teacher_model:
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule relative to a global batch size of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # criterion = LabelSmoothingCrossEntropy()
    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    # save config for later experiments
    with open(output_dir / "config.yaml", 'w') as f:
        f.write(args_text)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Only restore training state when continuing training (not eval).
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            if args.model_ema:
                # NOTE(review): `model_ema` is always None in this function
                # (never constructed); if `args.model_ema` is set this call
                # passes None into `_load_checkpoint_for_ema` — confirm
                # intended behavior.
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])

    # In retrain mode, pin the sub-network dimensions from the RETRAIN
    # section of the config instead of sampling from `choices`.
    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        retrain_config = {
            'layer_num': cfg.RETRAIN.DEPTH,
            'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH,
            'num_heads': cfg.RETRAIN.NUM_HEADS,
            'mlp_ratio': cfg.RETRAIN.MLP_RATIO
        }
    if args.eval:
        print(retrain_config)
        test_stats = evaluate(data_loader_val, model, device,
                              mode=args.mode, retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle per epoch for the distributed sampler.
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            amp=args.amp, teacher_model=teacher_model,
            teach_loss=teacher_loss,
            choices=choices, mode=args.mode,
            retrain_config=retrain_config,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # 'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device,
                              amp=args.amp, choices=choices,
                              mode=args.mode, retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(
    log_dir,
    dataset,
    im_size,
    crop_size,
    window_size,
    window_stride,
    backbone,
    decoder,
    optimizer,
    scheduler,
    weight_decay,
    dropout,
    drop_path,
    batch_size,
    epochs,
    learning_rate,
    normalization,
    eval_freq,
    amp,
    resume,
):
    """Train a segmentation model (encoder `backbone` + `decoder`) on `dataset`.

    Merges CLI overrides over the dataset/model defaults from the loaded
    config into a `variant` dict, builds train/val loaders, the model,
    optimizer and scheduler, optionally resumes from
    ``log_dir/checkpoint.pth``, wraps in DDP, then runs the epoch loop with
    checkpointing (rank 0 only), periodic evaluation, and JSON-line logging.

    CLI-style parameters: falsy values (None/0/'') mean "use the dataset
    config default".
    """
    # start distributed mode
    ptu.set_gpu_mode(True)
    distributed.init_process()

    # set up configuration
    cfg = config.load_config()
    model_cfg = cfg["model"][backbone]
    dataset_cfg = cfg["dataset"][dataset]
    if "mask_transformer" in decoder:
        decoder_cfg = cfg["decoder"]["mask_transformer"]
    else:
        decoder_cfg = cfg["decoder"][decoder]

    # model config: fill unset sizes from the dataset defaults
    if not im_size:
        im_size = dataset_cfg["im_size"]
    if not crop_size:
        crop_size = dataset_cfg.get("crop_size", im_size)
    if not window_size:
        window_size = dataset_cfg.get("window_size", im_size)
    if not window_stride:
        window_stride = dataset_cfg.get("window_stride", im_size)

    model_cfg["image_size"] = (crop_size, crop_size)
    model_cfg["backbone"] = backbone
    model_cfg["dropout"] = dropout
    model_cfg["drop_path_rate"] = drop_path
    decoder_cfg["name"] = decoder
    model_cfg["decoder"] = decoder_cfg

    # dataset config: CLI values override the dataset defaults when given
    world_batch_size = dataset_cfg["batch_size"]
    num_epochs = dataset_cfg["epochs"]
    lr = dataset_cfg["learning_rate"]
    if batch_size:
        world_batch_size = batch_size
    if epochs:
        num_epochs = epochs
    if learning_rate:
        lr = learning_rate
    if eval_freq is None:
        eval_freq = dataset_cfg.get("eval_freq", 1)

    if normalization:
        model_cfg["normalization"] = normalization

    # experiment config: per-process batch size from the global one
    batch_size = world_batch_size // ptu.world_size
    variant = dict(
        world_batch_size=world_batch_size,
        version="normal",
        resume=resume,
        dataset_kwargs=dict(
            dataset=dataset,
            image_size=im_size,
            crop_size=crop_size,
            batch_size=batch_size,
            normalization=model_cfg["normalization"],
            split="train",
            num_workers=10,
        ),
        algorithm_kwargs=dict(
            batch_size=batch_size,
            start_epoch=0,
            num_epochs=num_epochs,
            eval_freq=eval_freq,
        ),
        optimizer_kwargs=dict(
            opt=optimizer,
            lr=lr,
            weight_decay=weight_decay,
            momentum=0.9,
            clip_grad=None,
            sched=scheduler,
            epochs=num_epochs,
            min_lr=1e-5,
            poly_power=0.9,
            poly_step_size=1,
        ),
        net_kwargs=model_cfg,
        amp=amp,
        log_dir=log_dir,
        inference_kwargs=dict(
            im_size=im_size,
            window_size=window_size,
            window_stride=window_stride,
        ),
    )

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = log_dir / "checkpoint.pth"

    # dataset: val split uses batch size 1 and no cropping
    dataset_kwargs = variant["dataset_kwargs"]

    train_loader = create_dataset(dataset_kwargs)
    val_kwargs = dataset_kwargs.copy()
    val_kwargs["split"] = "val"
    val_kwargs["batch_size"] = 1
    val_kwargs["crop"] = False
    val_loader = create_dataset(val_kwargs)
    n_cls = train_loader.unwrapped.n_cls

    # model: class count comes from the training dataset
    net_kwargs = variant["net_kwargs"]
    net_kwargs["n_cls"] = n_cls
    model = create_segmenter(net_kwargs)
    model.to(ptu.device)

    # optimizer: create_optimizer/create_scheduler expect a Namespace,
    # so copy the kwargs onto one
    optimizer_kwargs = variant["optimizer_kwargs"]
    optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"]
    optimizer_kwargs["iter_warmup"] = 0.0
    opt_args = argparse.Namespace()
    opt_vars = vars(opt_args)
    for k, v in optimizer_kwargs.items():
        opt_vars[k] = v
    optimizer = create_optimizer(opt_args, model)
    lr_scheduler = create_scheduler(opt_args, optimizer)
    num_iterations = 0
    amp_autocast = suppress
    loss_scaler = None
    if amp:
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()

    # resume from an existing checkpoint in log_dir, if requested
    if resume and checkpoint_path.exists():
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if loss_scaler and "loss_scaler" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["loss_scaler"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1
    else:
        # No resume: make sure all ranks share the same initial weights.
        sync_model(log_dir, model)

    if ptu.distributed:
        model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)

    # save config
    variant_str = yaml.dump(variant)
    print(f"Configuration:\n{variant_str}")
    variant["net_kwargs"] = net_kwargs
    variant["dataset_kwargs"] = dataset_kwargs
    # NOTE(review): log_dir was already created above; this second mkdir is
    # a harmless no-op (exist_ok=True).
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / "variant.yml", "w") as f:
        f.write(variant_str)

    # train
    start_epoch = variant["algorithm_kwargs"]["start_epoch"]
    num_epochs = variant["algorithm_kwargs"]["num_epochs"]
    eval_freq = variant["algorithm_kwargs"]["eval_freq"]

    model_without_ddp = model
    if hasattr(model, "module"):
        model_without_ddp = model.module

    val_seg_gt = val_loader.dataset.get_gt_seg_maps()

    print(f"Train dataset length: {len(train_loader.dataset)}")
    print(f"Val dataset length: {len(val_loader.dataset)}")
    print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}")
    print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}")

    for epoch in range(start_epoch, num_epochs):
        # train for one epoch
        train_logger = train_one_epoch(
            model,
            train_loader,
            optimizer,
            lr_scheduler,
            epoch,
            amp_autocast,
            loss_scaler,
        )

        # save checkpoint (rank 0 only)
        if ptu.dist_rank == 0:
            snapshot = dict(
                model=model_without_ddp.state_dict(),
                optimizer=optimizer.state_dict(),
                n_cls=model_without_ddp.n_cls,
                lr_scheduler=lr_scheduler.state_dict(),
            )
            if loss_scaler is not None:
                snapshot["loss_scaler"] = loss_scaler.state_dict()
            snapshot["epoch"] = epoch
            torch.save(snapshot, checkpoint_path)

        # evaluate every `eval_freq` epochs and always on the last epoch
        eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1
        if eval_epoch:
            eval_logger = evaluate(
                model,
                val_loader,
                val_seg_gt,
                window_size,
                window_stride,
                amp_autocast,
            )
            print(f"Stats [{epoch}]:", eval_logger, flush=True)
            print("")

        # log stats (rank 0 only)
        if ptu.dist_rank == 0:
            train_stats = {
                k: meter.global_avg for k, meter in train_logger.meters.items()
            }
            val_stats = {}
            if eval_epoch:
                val_stats = {
                    k: meter.global_avg
                    for k, meter in eval_logger.meters.items()
                }

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"val_{k}": v for k, v in val_stats.items()},
                "epoch": epoch,
                "num_updates": (epoch + 1) * len(train_loader),
            }

            with open(log_dir / "log.txt", "a") as f:
                f.write(json.dumps(log_stats) + "\n")

    distributed.barrier()
    distributed.destroy_process()
    # NOTE(review): exits with status 1 even on success — presumably to
    # force worker-process teardown, but it reads as a failure to any
    # caller checking the exit code; confirm intent.
    sys.exit(1)
class TorchImageClassificationEstimator(BaseEstimator):
    """Torch Estimator implementation for Image Classification.

    Parameters
    ----------
    config : dict
        Config in nested dict.
    logger : logging.Logger
        Optional logger for this estimator, can be `None` when default setting is used.
    reporter : callable
        The reporter for metric checkpointing.
    net : torch.nn.Module
        The custom network. If defined, the model name in config will be ignored so your
        custom network will be used for training rather than pulling it from model zoo.
    """
    def __init__(self, config, logger=None, reporter=None, net=None, optimizer=None, problem_type=None):
        super().__init__(config, logger=logger, reporter=reporter, name=None)
        if problem_type is None:
            problem_type = MULTICLASS
        self._problem_type = problem_type
        self._feature_net = None
        # True when a user-supplied nn.Module is used instead of a model-zoo model
        self._custom_net = False
        # convenience aliases into the nested config
        self._img_cls_cfg = self._cfg.img_cls
        self._data_cfg = self._cfg.data
        self._optimizer_cfg = self._cfg.optimizer
        self._train_cfg = self._cfg.train
        self._augmentation_cfg = self._cfg.augmentation
        self._model_ema_cfg = self._cfg.model_ema
        self._misc_cfg = self._cfg.misc
        # resolve AMP arguments based on PyTorch / Apex availability
        self.use_amp = None
        if self._misc_cfg.amp:
            # `amp` chooses native amp before apex (APEX ver not actively maintained)
            if self._misc_cfg.native_amp and has_native_amp:
                self.use_amp = 'native'
            elif self._misc_cfg.apex_amp and has_apex:
                self.use_amp = 'apex'
            elif self._misc_cfg.apex_amp or self._misc_cfg.native_amp:
                self._logger.warning(f'Neither APEX or native Torch AMP is available, using float32. Install NVIDA apex or upgrade to PyTorch 1.6')
        # FIXME: will provided model conflict with config provided?
        if net is not None:
            assert isinstance(net, nn.Module), f"given custom network {type(net)}, `torch.nn` expected"
            try:
                # moving to cpu verifies the module is a real, materialized network
                net.to('cpu')
                self._custom_net = True
            except ValueError:
                pass
            self.net = net
        if optimizer is not None:
            self._logger.warning('Custom optimizer object not supported. Will follow the config instead.')
        # optimizer is always (re)built from config in _init_trainer
        self._optimizer = None

    def _fit(self, train_data, val_data, time_limit=math.inf):
        """Initialize trainer/EMA state and delegate to _resume_fit for the actual training."""
        tic = time.time()
        self._cp_name = ''
        self._best_acc = -float('inf')
        self.epochs = self._train_cfg.epochs
        self.epoch = 0
        self.start_epoch = self._train_cfg.start_epoch
        self._time_elapsed = 0
        if max(self.start_epoch, self.epoch) >= self.epochs:
            # NOTE(review): this is a SET literal, not a dict — siblings return
            # {'time': self._time_elapsed}; looks like a bug, confirm intent.
            return {'time', self._time_elapsed}
        self._init_trainer()
        self._init_model_ema()
        self._time_elapsed += time.time() - tic
        return self._resume_fit(train_data, val_data, time_limit=time_limit)

    def _resume_fit(self, train_data, val_data, time_limit=math.inf):
        """Build data loaders (with mixup/cutmix wiring) and run the training loop."""
        tic = time.time()
        # TODO: regression not implemented
        if self._problem_type != REGRESSION and (not self.classes or not self.num_class):
            raise ValueError('This is a classification problem and we are not able to determine classes of dataset')
        if max(self.start_epoch, self.epoch) >= self.epochs:
            return {'time': self._time_elapsed}
        # wrap DP if possible
        if self.found_gpu:
            self.net = torch.nn.DataParallel(self.net, device_ids=[int(i) for i in self.valid_gpus])
        self.net = self.net.to(self.ctx[0])
        # prepare dataset
        train_dataset = train_data.to_torch()
        val_dataset = val_data.to_torch()
        # setup mixup / cutmix
        self._collate_fn = None
        self._mixup_fn = None
        self.mixup_active = self._augmentation_cfg.mixup > 0 or self._augmentation_cfg.cutmix > 0. or self._augmentation_cfg.cutmix_minmax is not None
        if self.mixup_active:
            mixup_args = dict(
                mixup_alpha=self._augmentation_cfg.mixup,
                cutmix_alpha=self._augmentation_cfg.cutmix,
                cutmix_minmax=self._augmentation_cfg.cutmix_minmax,
                prob=self._augmentation_cfg.mixup_prob,
                switch_prob=self._augmentation_cfg.mixup_switch_prob,
                mode=self._augmentation_cfg.mixup_mode,
                label_smoothing=self._augmentation_cfg.smoothing,
                num_classes=self.num_class)
            if self._misc_cfg.prefetcher:
                # prefetcher path mixes at collate time, on CPU
                self._collate_fn = FastCollateMixup(**mixup_args)
            else:
                self._mixup_fn = Mixup(**mixup_args)
        # create data loaders w/ augmentation pipeiine
        train_interpolation = self._augmentation_cfg.train_interpolation
        if self._augmentation_cfg.no_aug or not train_interpolation:
            train_interpolation = self._data_cfg.interpolation
        train_loader = create_loader(
            train_dataset,
            input_size=self._data_cfg.input_size,
            batch_size=self._train_cfg.batch_size,
            is_training=True,
            use_prefetcher=self._misc_cfg.prefetcher,
            no_aug=self._augmentation_cfg.no_aug,
            scale=self._augmentation_cfg.scale,
            ratio=self._augmentation_cfg.ratio,
            hflip=self._augmentation_cfg.hflip,
            vflip=self._augmentation_cfg.vflip,
            color_jitter=self._augmentation_cfg.color_jitter,
            auto_augment=self._augmentation_cfg.auto_augment,
            interpolation=train_interpolation,
            mean=self._data_cfg.mean,
            std=self._data_cfg.std,
            num_workers=self._misc_cfg.num_workers,
            distributed=False,
            collate_fn=self._collate_fn,
            pin_memory=self._misc_cfg.pin_mem,
            use_multi_epochs_loader=self._misc_cfg.use_multi_epochs_loader
        )
        val_loader = create_loader(
            val_dataset,
            input_size=self._data_cfg.input_size,
            batch_size=self._data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size,
            is_training=False,
            use_prefetcher=self._misc_cfg.prefetcher,
            interpolation=self._data_cfg.interpolation,
            mean=self._data_cfg.mean,
            std=self._data_cfg.std,
            num_workers=self._misc_cfg.num_workers,
            distributed=False,
            crop_pct=self._data_cfg.crop_pct,
            pin_memory=self._misc_cfg.pin_mem,
        )
        self._time_elapsed += time.time() - tic
        return self._train_loop(train_loader, val_loader, time_limit=time_limit)

    def _train_loop(self, train_loader, val_loader, time_limit=math.inf):
        """Epoch loop: train, validate, early-stop, checkpoint the best model.

        Returns a summary dict with train/valid scores, elapsed time, and best
        checkpoint path. Regression uses rmse (stored negated in _best_acc so
        'higher is better' logic is shared with accuracy).
        """
        start_tic = time.time()
        # setup loss function
        if self.mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy()
        elif self._augmentation_cfg.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=self._augmentation_cfg.smoothing)
        else:
            train_loss_fn = nn.CrossEntropyLoss()
        validate_loss_fn = nn.CrossEntropyLoss()
        train_loss_fn = train_loss_fn.to(self.ctx[0])
        validate_loss_fn = validate_loss_fn.to(self.ctx[0])
        eval_metric = self._misc_cfg.eval_metric
        if self._problem_type == REGRESSION:
            train_loss_fn = nn.MSELoss()
            validate_loss_fn = nn.MSELoss()
            eval_metric = 'rmse'
        early_stopper = EarlyStopperOnPlateau(
            patience=self._train_cfg.early_stop_patience,
            min_delta=self._train_cfg.early_stop_min_delta,
            baseline_value=self._train_cfg.early_stop_baseline,
            max_value=self._train_cfg.early_stop_max_value)
        self._logger.info('Start training from [Epoch %d]', max(self._train_cfg.start_epoch, self.epoch))
        self._time_elapsed += time.time() - start_tic
        for self.epoch in range(max(self.start_epoch, self.epoch), self.epochs):
            epoch = self.epoch
            if self._best_acc >= 1.0:
                self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch))
                break
            should_stop, stop_message = early_stopper.get_early_stop_advice()
            if should_stop:
                self._logger.info('[Epoch {}] '.format(epoch) + stop_message)
                break
            train_metrics = self.train_one_epoch(
                epoch, self.net, train_loader, self._optimizer, train_loss_fn,
                lr_scheduler=self._lr_scheduler, output_dir=self._logdir,
                amp_autocast=self._amp_autocast, loss_scaler=self._loss_scaler,
                model_ema=self._model_ema, mixup_fn=self._mixup_fn, time_limit=time_limit)
            # reaching time limit, exit early
            if train_metrics['time_limit']:
                self._logger.warning(f'`time_limit={time_limit}` reached, exit early...')
                return {'train_acc': train_metrics['train_acc'], 'valid_acc': self._best_acc,
                        'time': self._time_elapsed, 'checkpoint': self._cp_name}
            post_tic = time.time()
            eval_metrics = self.validate(self.net, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast)
            if self._model_ema is not None and not self._model_ema_cfg.model_ema_force_cpu:
                # prefer EMA weights for evaluation when EMA lives on device
                ema_eval_metrics = self.validate(
                    self._model_ema.module, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast)
                eval_metrics = ema_eval_metrics
            if self._problem_type == REGRESSION:
                val_acc = eval_metrics['rmse']
                if self._reporter:
                    self._reporter(epoch=epoch, acc_reward=-val_acc)
                early_stopper.update(-val_acc)
                # negate rmse so that "greater is better" comparison applies
                if -val_acc > self._best_acc:
                    self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE)
                    self._logger.info('[Epoch %d] Current best rmse: %f vs previous %f, saved to %s',
                                      self.epoch, val_acc, -self._best_acc, self._cp_name)
                    self.save(self._cp_name)
                    self._best_acc = -val_acc
            else:
                val_acc = eval_metrics['top1']
                if self._reporter:
                    self._reporter(epoch=epoch, acc_reward=val_acc)
                early_stopper.update(val_acc)
                if val_acc > self._best_acc:
                    self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE)
                    self._logger.info('[Epoch %d] Current best top-1: %f vs previous %f, saved to %s',
                                      self.epoch, val_acc, self._best_acc, self._cp_name)
                    self.save(self._cp_name)
                    self._best_acc = val_acc
            if self._lr_scheduler is not None:
                # step LR for next epoch
                self._lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            self._time_elapsed += time.time() - post_tic
        if 'accuracy' in train_metrics:
            return {'train_acc': train_metrics['accuracy'], 'valid_acc': self._best_acc,
                    'time': self._time_elapsed, 'checkpoint': self._cp_name}
        # rmse
        else:
            if self._problem_type == REGRESSION:
                return {'train_score': train_metrics['rmse'], 'valid_score': -self._best_acc,
                        'time': self._time_elapsed, 'checkpoint': self._cp_name}
            # mixup
            else:
                return {'train_score': train_metrics['rmse'], 'valid_acc': self._best_acc,
                        'time': self._time_elapsed, 'checkpoint': self._cp_name}

    def train_one_epoch(
            self,
            epoch,
            net,
            loader,
            optimizer,
            loss_fn,
            lr_scheduler=None,
            output_dir=None,
            amp_autocast=suppress,
            loss_scaler=None,
            model_ema=None,
            mixup_fn=None,
            time_limit=math.inf):
        """Run one training epoch; returns {metric_name, 'train_loss', 'time_limit'}.

        'time_limit' is True when the wall-clock budget was exhausted mid-epoch
        (the partial averages are still returned).
        """
        start_tic = time.time()
        # optionally disable mixup after a configured epoch
        if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch:
            if self._misc_cfg.prefetcher and loader.mixup_enabled:
                loader.mixup_enabled = False
            elif mixup_fn is not None:
                mixup_fn.mixup_enabled = False
        second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        losses_m = AverageMeter()
        train_metric_score_m = AverageMeter()
        net.train()
        num_updates = epoch * len(loader)
        self._time_elapsed += time.time() - start_tic
        tic = time.time()
        last_tic = time.time()
        train_metric_name = 'accuracy'
        batch_idx = 0
        for batch_idx, (input, target) in enumerate(loader):
            b_tic = time.time()
            if self._time_elapsed > time_limit:
                return {'train_acc': train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': True}
            if self._problem_type == REGRESSION:
                target = target.to(torch.float32)
            if not self._misc_cfg.prefetcher:
                # prefetcher would move data to cuda by default
                input, target = input.to(self.ctx[0]), target.to(self.ctx[0])
                if mixup_fn is not None:
                    input, target = mixup_fn(input, target)
            with amp_autocast():
                output = net(input)
                if self._problem_type == REGRESSION:
                    output = output.flatten()
                loss = loss_fn(output, target)
            if self._problem_type == REGRESSION:
                train_metric_name = 'rmse'
                train_metric_score = rmse(output, target)
            else:
                # soft targets (mixup) have the same shape as output -> report rmse
                if output.shape == target.shape:
                    train_metric_name = 'rmse'
                    train_metric_score = rmse(output, target)
                else:
                    train_metric_score = accuracy(output, target)[0] / 100
            losses_m.update(loss.item(), input.size(0))
            train_metric_score_m.update(train_metric_score.item(), output.size(0))
            optimizer.zero_grad()
            if loss_scaler is not None:
                # AMP path: scaler performs backward + clipping + step
                loss_scaler(
                    loss, optimizer,
                    clip_grad=self._optimizer_cfg.clip_grad,
                    clip_mode=self._optimizer_cfg.clip_mode,
                    parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                    create_graph=second_order)
            else:
                loss.backward(create_graph=second_order)
                if self._optimizer_cfg.clip_grad is not None:
                    dispatch_clip_grad(
                        model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                        value=self._optimizer_cfg.clip_grad,
                        mode=self._optimizer_cfg.clip_mode)
                optimizer.step()
            if model_ema is not None:
                model_ema.update(net)
            if self.found_gpu:
                torch.cuda.synchronize()
            num_updates += 1
            if (batch_idx+1) % self._misc_cfg.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)
                self._logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f',
                                  epoch, batch_idx,
                                  self._train_cfg.batch_size*self._misc_cfg.log_interval/(time.time()-last_tic),
                                  train_metric_name, train_metric_score_m.avg, lr)
                last_tic = time.time()
                if self._misc_cfg.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)
            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
            self._time_elapsed += time.time() - b_tic
        throughput = int(self._train_cfg.batch_size * batch_idx / (time.time() - tic))
        self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score_m.avg)
        self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic)
        end_time = time.time()
        if hasattr(optimizer, 'sync_lookahead'):
            optimizer.sync_lookahead()
        self._time_elapsed += time.time() - end_time
        return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': False}

    def validate(self, net, loader, loss_fn, amp_autocast=suppress, metric_name=None):
        """Evaluate `net` on `loader`; returns {'loss', 'rmse'} or {'loss', 'top1', 'top5'}."""
        losses_m = AverageMeter()
        top1_m = AverageMeter()
        top5_m = AverageMeter()
        rmse_m = AverageMeter()
        net.eval()
        with torch.no_grad():
            for batch_idx, (input, target) in enumerate(loader):
                if not self._misc_cfg.prefetcher:
                    input = input.to(self.ctx[0])
                    target = target.to(self.ctx[0])
                with amp_autocast():
                    output = net(input)
                    if self._problem_type == REGRESSION:
                        output = output.flatten()
                if isinstance(output, (tuple, list)):
                    output = output[0]
                if self._problem_type == REGRESSION:
                    if metric_name:
                        assert metric_name == 'rmse', f'{metric_name} metric not supported for regression.'
                    val_metric_score = rmse(output, target)
                else:
                    val_metric_score = accuracy(output, target, topk=(1, min(5, self.num_class)))
                # augmentation reduction
                reduce_factor = self._misc_cfg.tta
                if self._problem_type != REGRESSION and reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]
                loss = loss_fn(output, target)
                reduced_loss = loss.data
                if self.found_gpu:
                    torch.cuda.synchronize()
                losses_m.update(reduced_loss.item(), input.size(0))
                if self._problem_type == REGRESSION:
                    rmse_score = val_metric_score
                    rmse_m.update(rmse_score.item(), output.size(0))
                else:
                    acc1, acc5 = val_metric_score
                    # accuracy() reports percentages; store as fractions
                    acc1 /= 100
                    acc5 /= 100
                    top1_m.update(acc1.item(), output.size(0))
                    top5_m.update(acc5.item(), output.size(0))
        if self._problem_type == REGRESSION:
            self._logger.info('[Epoch %d] validation: rmse=%f', self.epoch, rmse_m.avg)
            return {'loss': losses_m.avg, 'rmse': rmse_m.avg}
        else:
            self._logger.info('[Epoch %d] validation: top1=%f top5=%f', self.epoch, top1_m.avg, top5_m.avg)
            return {'loss': losses_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg}

    def _init_network(self, **kwargs):
        """Resolve devices, build (or validate) the network, and apply SyncBN/torchscript options.

        kwargs: load_only (bool) — skip downloading pretrained weights when the
        state dict will be loaded right after.
        """
        load_only = kwargs.get('load_only', False)
        if not self.num_class and self._problem_type != REGRESSION:
            raise ValueError('This is a classification problem and we are not able to create network when `num_class` is unknown. It should be inferred from dataset or resumed from saved states.')
        assert len(self.classes) == self.num_class
        # Disable syncBatchNorm as it's only supported on DDP
        if self._train_cfg.sync_bn:
            self._logger.info('Disable Sync batch norm as it is not supported for now.')
            update_cfg(self._cfg, {'train': {'sync_bn': False}})
        # ctx
        self.found_gpu = False
        valid_gpus = []
        if self._cfg.gpus:
            valid_gpus = self._torch_validate_gpus(self._cfg.gpus)
            self.found_gpu = True
            if not valid_gpus:
                self.found_gpu = False
                self._logger.warning('No gpu detected, fallback to cpu. You can ignore this warning if this is intended.')
            elif len(valid_gpus) != len(self._cfg.gpus):
                self._logger.warning(f'Loaded on gpu({valid_gpus}), different from gpu({self._cfg.gpus}).')
        self.ctx = [torch.device(f'cuda:{gid}') for gid in valid_gpus] if self.found_gpu else [torch.device('cpu')]
        self.valid_gpus = valid_gpus
        # CPU fallbacks: AMP, prefetcher and SyncBN all require CUDA
        if not self.found_gpu and self.use_amp:
            self.use_amp = None
            self._logger.warning('Training on cpu. AMP disabled.')
            update_cfg(self._cfg, {'misc': {'amp': False, 'apex_amp': False, 'native_amp': False}})
        if not self.found_gpu and self._misc_cfg.prefetcher:
            self._logger.warning('Training on cpu. Prefetcher disabled.')
            update_cfg(self._cfg, {'misc': {'prefetcher': False}})
            self._logger.warning('Training on cpu. SyncBatchNorm disabled.')
            update_cfg(self._cfg, {'train': {'sync_bn': False}})
        random_seed(self._misc_cfg.seed)
        if not self.net:
            self.net = create_model(
                self._img_cls_cfg.model,
                pretrained=self._img_cls_cfg.pretrained and not load_only,
                num_classes=max(self.num_class, 1),
                global_pool=self._img_cls_cfg.global_pool_type,
                drop_rate=self._augmentation_cfg.drop,
                drop_path_rate=self._augmentation_cfg.drop_path,
                drop_block_rate=self._augmentation_cfg.drop_block,
                bn_momentum=self._train_cfg.bn_momentum,
                bn_eps=self._train_cfg.bn_eps,
                scriptable=self._misc_cfg.torchscript
            )
            self._logger.info(f'Model {safe_model_name(self._img_cls_cfg.model)} created, param count: {sum([m.numel() for m in self.net.parameters()])}')
        else:
            self._logger.info(f'Use user provided model. Neglect model in config.')
            # sanity-check the custom model's output layer against the problem type
            out_features = list(self.net.children())[-1].out_features
            if self._problem_type != REGRESSION:
                assert out_features == self.num_class, f'Custom model out_feature {out_features} != num_class {self.num_class}.'
            else:
                assert out_features == 1, f'Regression problem expects num_out_feature == 1, got {out_features} instead.'
        resolve_data_config(self._cfg, model=self.net)
        self.net = self.net.to(self.ctx[0])
        # setup synchronized BatchNorm
        if self._train_cfg.sync_bn:
            if has_apex and self.use_amp != 'native':
                # Apex SyncBN preferred unless native amp is activated
                self.net = convert_syncbn_model(self.net)
            else:
                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
            self._logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
        if self._misc_cfg.torchscript:
            assert not self.use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
            assert not self._train_cfg.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
            self.net = torch.jit.script(self.net)

    def _init_trainer(self):
        """Create optimizer (from config unless already set), loss scaler and LR scheduler."""
        if self._optimizer is None:
            if self._img_cls_cfg.pretrained and not self._custom_net \
                    and (self._train_cfg.transfer_lr_mult != 1 or self._train_cfg.output_lr_mult != 1):
                # adjust feature/last_fc learning rate multiplier in optimizer
                self._logger.debug(f'Reduce network lr multiplier to {self._train_cfg.transfer_lr_mult}, while keep ' +
                                   f'last FC layer lr_mult to {self._train_cfg.output_lr_mult}')
                optim_kwargs = optimizer_kwargs(cfg=self._cfg)
                optim_kwargs['feature_lr_mult'] = self._cfg.train.transfer_lr_mult
                optim_kwargs['fc_lr_mult'] = self._cfg.train.output_lr_mult
                # NOTE(review): `optim_kwargs` (with the lr multipliers) is built but not
                # passed here — the call re-derives kwargs from config, so the multipliers
                # are silently dropped. Looks like a bug; confirm against intent.
                self._optimizer = create_optimizer_v2(self.net, **optimizer_kwargs(cfg=self._cfg))
            else:
                self._optimizer = create_optimizer_v2(self.net, **optimizer_kwargs(cfg=self._cfg))
        self._init_loss_scaler()
        self._lr_scheduler, self.epochs = create_scheduler(self._cfg, self._optimizer)
        self._lr_scheduler.step(self.start_epoch, self.epoch)

    def _init_loss_scaler(self):
        # setup automatic mixed-precision (AMP) loss scaling and op casting
        self._amp_autocast = suppress  # do nothing
        self._loss_scaler = None
        if self.use_amp == 'apex':
            self.net, self._optimizer = amp.initialize(self.net, self._optimizer, opt_level='O1')
            self._loss_scaler = ApexScaler()
            self._logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
        elif self.use_amp == 'native':
            self._amp_autocast = torch.cuda.amp.autocast
            self._loss_scaler = NativeScaler()
            self._logger.info('Using native Torch AMP. Training in mixed precision.')
        else:
            self._logger.info('AMP not enabled. Training in float32.')

    def _init_model_ema(self):
        """Set up model EMA; currently force-disabled via config before the feature check."""
        # Disable for now
        if self._model_ema_cfg.model_ema:
            self._logger.info('Disable EMA as it is not supported for now.')
            update_cfg(self._cfg, {'model_ema': {'model_ema': False}})
        # setup exponential moving average of model weights, SWA could be used here too
        self._model_ema = None
        if self._model_ema_cfg.model_ema:
            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
            self._model_ema = ModelEmaV2(
                self.net,
                decay=self._model_ema_cfg.model_ema_decay,
                device='cpu' if self._model_ema_cfg.model_ema_force_cpu else None)

    def evaluate(self, val_data, metric_name=None):
        """Public evaluation entry point; see _evaluate."""
        return self._evaluate(val_data, metric_name)

    def _evaluate(self, val_data, metric_name=None):
        """Build a validation loader for `val_data` and run validate()."""
        if self._problem_type == REGRESSION:
            validate_loss_fn = nn.MSELoss()
        else:
            validate_loss_fn = nn.CrossEntropyLoss()
        validate_loss_fn = validate_loss_fn.to(self.ctx[0])
        val_data = val_data.to_torch()
        val_loader = create_loader(
            val_data,
            input_size=self._data_cfg.input_size,
            batch_size=self._data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size,
            is_training=False,
            use_prefetcher=self._misc_cfg.prefetcher,
            interpolation=self._data_cfg.interpolation,
            mean=self._data_cfg.mean,
            std=self._data_cfg.std,
            num_workers=self._misc_cfg.num_workers,
            distributed=False,
            crop_pct=self._data_cfg.crop_pct,
            pin_memory=self._misc_cfg.pin_mem,
        )
        return self.validate(self.net, val_loader, validate_loss_fn,
                             amp_autocast=self._amp_autocast, metric_name=metric_name)

    def _predict(self, x, **kwargs):
        """Predict classes/scores (or probabilities with with_proba=True) for
        a path, DataFrame with `image` column, list/tuple of paths, or a
        (n, 3, h, w) tensor. Returns a pandas DataFrame.
        """
        with_proba = kwargs.get('with_proba', False)
        if with_proba and self._problem_type not in [MULTICLASS, BINARY]:
            raise AssertionError('with_proba is only supported for classification problems. Please use predict instead.')
        if isinstance(x, str):
            return self._predict((x,), **kwargs).drop(columns=['image'], errors='ignore')
        elif isinstance(x, pd.DataFrame):
            assert 'image' in x.columns, "Expect column `image` for input images"
            df = self._predict(tuple(x['image']), **kwargs)
            return df.reset_index(drop=True)
        elif isinstance(x, (list, tuple)):
            loader = create_loader(
                ImageListDataset(x),
                input_size=self._data_cfg.input_size,
                batch_size=self._train_cfg.batch_size,
                use_prefetcher=self._misc_cfg.prefetcher,
                interpolation=self._data_cfg.interpolation,
                mean=self._data_cfg.mean,
                std=self._data_cfg.std,
                num_workers=self._misc_cfg.num_workers,
                crop_pct=self._data_cfg.crop_pct
            )
            self.net.eval()
            topk = min(5, self.num_class)
            results = []
            idx = 0
            with torch.no_grad():
                for input, _ in loader:
                    input = input.to(self.ctx[0])
                    labels = self.net(input)
                    for l in labels:
                        if self._problem_type in [MULTICLASS, BINARY]:
                            probs = nn.functional.softmax(l, dim=0).cpu().numpy().flatten()
                            if with_proba:
                                results.append({'image_proba': probs.tolist(), 'image': x[idx]})
                            else:
                                topk_inds = l.topk(topk)[1].cpu().numpy().flatten()
                                results.extend([{'class': self.classes[topk_inds[k]],
                                                 'score': probs[topk_inds[k]],
                                                 'id': topk_inds[k],
                                                 'image': x[idx]}
                                                for k in range(topk)])
                        else:
                            results.append({'prediction': l.cpu().numpy().flatten(), 'image': x[idx]})
                        idx += 1
            return pd.DataFrame(results)
        elif not isinstance(x, torch.Tensor):
            raise ValueError('Input is not supported: {}'.format(type(x)))
        # raw tensor path: single batch, already preprocessed
        assert len(x.shape) == 4 and x.shape[1] == 3, f"Expect input to be (n, 3, h, w), given {x.shape}"
        with torch.no_grad():
            input = x.to(self.ctx[0])
            label = self.net(input)
            if self._problem_type in [MULTICLASS, BINARY]:
                topk = min(5, self.num_class)
                probs = nn.functional.softmax(label, dim=1).cpu().numpy().flatten()
                topk_inds = label.topk(topk)[1].cpu().numpy().flatten()
                if with_proba:
                    df = pd.DataFrame([{'image_proba': probs.tolist()}])
                else:
                    df = pd.DataFrame([{'class': self.classes[topk_inds[k]],
                                        'score': probs[topk_inds[k]],
                                        'id': topk_inds[k]}
                                       for k in range(topk)])
            else:
                df = pd.DataFrame([{'prediction': label.cpu().numpy().flatten()}])
        return df

    def _predict_feature(self, x, **kwargs):
        """Extract backbone features for the same input forms as _predict; returns a DataFrame."""
        if isinstance(x, str):
            return self._predict_feature((x,))
        elif isinstance(x, pd.DataFrame):
            assert 'image' in x.columns, "Expect column `image` for input images"
            df = self._predict_feature(tuple(x['image']))
            df = df.set_index(x.index)
            df['image'] = x['image']
            return df
        elif isinstance(x, (list, tuple)):
            assert isinstance(x[0], str), "expect image paths in list/tuple input"
            loader = create_loader(
                ImageListDataset(x),
                input_size=self._data_cfg.input_size,
                batch_size=self._train_cfg.batch_size,
                use_prefetcher=self._misc_cfg.prefetcher,
                interpolation=self._data_cfg.interpolation,
                mean=self._data_cfg.mean,
                std=self._data_cfg.std,
                num_workers=self._misc_cfg.num_workers,
                crop_pct=self._data_cfg.crop_pct
            )
            self.net.eval()
            results = []
            with torch.no_grad():
                for input, _ in loader:
                    input = input.to(self.ctx[0])
                    try:
                        features = self.net.forward_features(input)
                    except AttributeError:
                        # net may be wrapped (e.g. DataParallel) — reach through .module
                        features = self.net.module.forward_features(input)
                    for f in features:
                        f = f.cpu().numpy().flatten()
                        results.append({'image_feature': f})
            df = pd.DataFrame(results)
            df['image'] = x
            return df
        elif not isinstance(x, torch.Tensor):
            raise ValueError('Input is not supported: {}'.format(type(x)))
        with torch.no_grad():
            input = x.to(self.ctx[0])
            feature = self.net.forward_features(input)
            result = [{'image_feature': feature}]
            df = pd.DataFrame(result)
        return df

    def _reconstruct_state_dict(self, state_dict):
        """Strip the 'module.' prefix DataParallel adds so weights load into a bare model."""
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module') else k
            new_state_dict[name] = v
        return new_state_dict

    def save(self, filename):
        """Serialize estimator + weights to `filename`; restores in-memory state afterwards.

        Non-picklable members (net/optimizer/ema/logger/reporter) are detached
        before pickling the estimator itself, then reattached.
        """
        d = dict()
        current_states = self.__dict__.copy()
        if self.net:
            if not self._custom_net:
                if isinstance(self.net, torch.nn.DataParallel):
                    d['model_state_dict'] = get_state_dict(self.net.module, unwrap_model)
                else:
                    d['model_state_dict'] = get_state_dict(self.net, unwrap_model)
            else:
                net_pickle = pickle.dumps(self.net)
                d['net_pickle'] = net_pickle
            self.net = None
        if self._optimizer:
            d['optimizer_state_dict'] = self._optimizer.state_dict()
            self._optimizer = None
        if hasattr(self, '_loss_scaler') and self._loss_scaler:
            d[self._loss_scaler.state_dict_key] = self._loss_scaler.state_dict()
            d['_loss_scaler_state_dict_key'] = self._loss_scaler.state_dict_key
        if self._model_ema:
            d['ema_state_dict'] = get_state_dict(self._model_ema, unwrap_model)
            self._model_ema = None
        self._logger = None
        self._reporter = None
        d['estimator'] = self
        torch.save(d, filename)
        # restore the detached members
        self.__dict__.update(current_states)

    @classmethod
    def load(cls, filename, ctx='auto'):
        """Deserialize an estimator saved by save(); rebuilds net/optimizer/EMA from state dicts."""
        d = torch.load(filename, map_location=torch.device('cpu'))
        est = d.pop('estimator')
        # logger
        est._logger = logging.getLogger(cls.__name__)
        est._logger.setLevel(logging.ERROR)
        try:
            fh = logging.FileHandler(est._log_file)
            est._logger.addHandler(fh)
        #pylint: disable=bare-except
        except:
            pass
        model_state_dict = d.get('model_state_dict', None)
        net_pickle = d.get('net_pickle', None)
        if model_state_dict:
            est._init_network(load_only=True)
            net_state_dict = est._reconstruct_state_dict(model_state_dict)
            if isinstance(est.net, torch.nn.DataParallel):
                est.net.module.load_state_dict(net_state_dict)
            else:
                est.net.load_state_dict(net_state_dict)
        elif net_pickle:
            est.net = pickle.loads(net_pickle)
        optimizer_state_dict = d.get('optimizer_state_dict', None)
        if optimizer_state_dict:
            est._init_trainer()
            est._optimizer.load_state_dict(optimizer_state_dict)
        if hasattr(est, '_loss_scaler') and est._loss_scaler:
            # NOTE(review): save() stores the key under '_loss_scaler_state_dict_key'
            # but this reads 'loss_scaler_state_dict' — the scaler state is never
            # restored. Looks like a key-mismatch bug; confirm.
            loss_scaler_state_dict_key = d.get('loss_scaler_state_dict')
            loss_scaler_dict = d.get(loss_scaler_state_dict_key, None)
            if loss_scaler_dict:
                est._loss_scaler.load_state_dict(loss_scaler_dict)
        ema_state_dict = d.get('ema_state_dict', None)
        est._init_model_ema()
        if ema_state_dict:
            ema_state_dict = est._reconstruct_state_dict(ema_state_dict)
            if isinstance(est.net, torch.nn.DataParallel):
                est._model_ema.module.module.load_state_dict(ema_state_dict)
            else:
                est._model_ema.module.load_state_dict(ema_state_dict)
        new_ctx = _suggest_load_context(est.net, ctx, est.ctx)
        est.reset_ctx(new_ctx)
        est._logger.setLevel(logging.INFO)
        return est

    # pylint: disable=redefined-outer-name, reimported
    def __getstate__(self):
        """Pickle support: convert torch members to state dicts under 'save_state'."""
        d = self.__dict__.copy()
        try:
            import torch
            net = d.pop('net', None)
            model_ema = d.pop('_model_ema', None)
            optimizer = d.pop('_optimizer', None)
            loss_scaler = d.pop('_loss_scaler', None)
            save_state = {}
            if net is not None:
                if not self._custom_net:
                    if isinstance(net, torch.nn.DataParallel):
                        save_state['state_dict'] = get_state_dict(net.module, unwrap_model)
                    else:
                        save_state['state_dict'] = get_state_dict(net, unwrap_model)
                else:
                    net_pickle = pickle.dumps(net)
                    save_state['net_pickle'] = net_pickle
            if optimizer is not None:
                save_state['optimizer'] = optimizer.state_dict()
            if loss_scaler is not None:
                save_state[loss_scaler.state_dict_key] = loss_scaler.state_dict()
            if model_ema is not None:
                save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
        except ImportError:
            # NOTE(review): if the torch import fails, `save_state` is never bound
            # and the line below raises NameError — confirm whether this path is reachable.
            pass
        d['save_state'] = save_state
        d['_logger'] = None
        d['_reporter'] = None
        return d

    def __setstate__(self, state):
        """Unpickle support: restore plain attributes, then rebuild torch members from 'save_state'."""
        save_state = state.pop('save_state', None)
        self.__dict__.update(state)
        # logger
        self._logger = logging.getLogger(state.get('_name', self.__class__.__name__))
        self._logger.setLevel(logging.ERROR)
        try:
            fh = logging.FileHandler(self._log_file)
            self._logger.addHandler(fh)
        #pylint: disable=bare-except
        except:
            pass
        if not save_state:
            self.net = None
            self._optimizer = None
            self._logger.setLevel(logging.INFO)
            return
        try:
            import torch
            self.net = None
            self._optimizer = None
            if self._custom_net:
                if save_state.get('net_pickle', None):
                    self.net = pickle.loads(save_state['net_pickle'])
            else:
                if save_state.get('state_dict', None):
                    self._init_network(load_only=True)
                    net_state_dict = self._reconstruct_state_dict(save_state['state_dict'])
                    if isinstance(self.net, torch.nn.DataParallel):
                        self.net.module.load_state_dict(net_state_dict)
                    else:
                        self.net.load_state_dict(net_state_dict)
            if save_state.get('optimizer', None):
                self._init_trainer()
                self._optimizer.load_state_dict(save_state['optimizer'])
            if hasattr(self, '_loss_scaler') and self._loss_scaler and self._loss_scaler.state_dict_key in save_state:
                loss_scaler_dict = save_state[self._loss_scaler.state_dict_key]
                self._loss_scaler.load_state_dict(loss_scaler_dict)
            if save_state.get('state_dict_ema', None):
                self._init_model_ema()
                model_ema_dict = save_state.get('state_dict_ema')
                model_ema_dict = self._reconstruct_state_dict(model_ema_dict)
                if isinstance(self.net, torch.nn.DataParallel):
                    self._model_ema.module.module.load_state_dict(model_ema_dict)
                else:
                    self._model_ema.module.load_state_dict(model_ema_dict)
        except ImportError:
            pass
        self._logger.setLevel(logging.INFO)
def main(args):
    """Training/evaluation entry point: distributed setup, data, model, optimizer,
    optional resume, then the epoch loop with per-epoch checkpointing and logging.
    """
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)
    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)
    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        # dead branch (kept for parity with the non-distributed code path)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  sampler=sampler_val,
                                                  batch_size=int(1.0 * args.batch_size),
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)
    # mixup/cutmix are only enabled when any of the three knobs is active
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)
    print(f"Creating model: {args.model}")
    # model = create_model(
    #     args.model,
    #     pretrained=False,
    #     num_classes=args.nb_classes,
    #     drop_rate=args.drop,
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=None,
    # )
    # model is looked up by name on the SwinTransformer module/namespace
    model = getattr(SwinTransformer, args.model)(num_classes=args.nb_classes,
                                                 drop_path_rate=args.drop_path)
    model.to(device)
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)
    # linear LR scaling by global batch size (reference batch 512)
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    linear_scaled_warmup_lr = args.warmup_lr * args.batch_size * utils.get_world_size() / 512.0
    args.warmup_lr = linear_scaled_warmup_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)
    # criterion = LabelSmoothingCrossEntropy()
    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # optimizer/scheduler state is only restored when continuing training
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    # NOTE(review): max_accuracy is not restored from the checkpoint on resume.
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # reshuffle shards per epoch
            data_loader_train.sampler.set_epoch(epoch)
        # scheduler is stepped before the epoch runs
        lr_scheduler.step(epoch + 1)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, mixup_fn,
            set_training_mode=True  # always train mode here (no finetune-specific eval-mode handling)
        )
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')
        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters
        }
        # only rank 0 appends to the shared log file
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    """Entry point: build data pipeline and supernet, then train or evaluate.

    Merges the YAML search-space/supernet configuration from ``args.cfg`` into
    the global ``cfg`` object and drives an ``AFSupernetTrainer`` for
    ``args.epochs`` epochs (or a single validation pass when ``args.eval``).
    """
    utils.init_distributed_mode(args)
    # merge the YAML config file into the global `cfg` used below
    update_config_from_file(args.cfg)
    print(args)
    # snapshot of the resolved arguments, written to output_dir/config.yaml below
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility (per-rank offset so ranks draw different data)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            # repeated-augmentation sampler: each image is seen by several ranks
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # evaluation uses half the training batch size
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=args.batch_size // 2,
        sampler=sampler_val, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print("Creating SuperVisionTransformer")
    print(cfg)
    # supernet covering the maximal architecture defined by cfg.SUPERNET
    model = Vision_TransformerSuper(img_size=args.input_size,
                                    patch_size=args.patch_size,
                                    embed_dim=cfg.SUPERNET.EMBED_DIM,
                                    depth=cfg.SUPERNET.DEPTH,
                                    num_heads=cfg.SUPERNET.NUM_HEADS,
                                    mlp_ratio=cfg.SUPERNET.MLP_RATIO,
                                    qkv_bias=True, drop_rate=args.drop,
                                    drop_path_rate=args.drop_path,
                                    gp=args.gp,
                                    num_classes=args.nb_classes,
                                    max_relative_position=args.max_relative_position,
                                    relative_position=args.relative_position,
                                    change_qkv=args.change_qkv,
                                    abs_pos=not args.no_abs_pos)

    # per-dimension candidate values the trainer samples subnets from
    choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS,
               'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
               'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM,
               'depth': cfg.SEARCH_SPACE.DEPTH}

    model.to(device)

    if args.teacher_model:
        # optional distillation teacher: pretrained model from the model registry
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    # EMA is not used in this script; kept as a placeholder for the trainer API
    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        # keep a handle on the bare module for checkpointing / optimizer creation
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # linear LR scaling rule: lr *= total_batch_size / 512 (mutates args in place)
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    # save config for later experiments
    with open(file=output_dir / "config.yaml", mode='w') as f:
        f.write(args_text)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # restore training state only when resuming a full training run
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            if args.model_ema:
                # NOTE(review): model_ema is always None in this script, so this
                # call would fail if args.model_ema is set — confirm intended usage.
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])

    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        # fixed architecture used when retraining a single searched subnet
        retrain_config = {'layer_num': cfg.RETRAIN.DEPTH,
                          'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH,
                          'num_heads': cfg.RETRAIN.NUM_HEADS,
                          'mlp_ratio': cfg.RETRAIN.MLP_RATIO}

    trainer = AFSupernetTrainer(
        model, criterion, data_loader_train, data_loader_val, optimizer, device,
        args.epochs, loss_scaler, args.clip_grad, model_ema, mixup_fn, args.amp,
        teacher_model, teacher_loss, choices, args.mode, retrain_config, 0.,
        output_dir, lr_scheduler,
    )
    if args.eval:
        # evaluation-only run: validate once (epoch index -1) and exit
        trainer._validate_one_epoch(-1)
        return
    trainer.fit()
def main(args):
    """Entry point: supervised training/evaluation loop for a registry model (PVT variant).

    Supports distributed training, optional finetuning with per-parameter-group
    learning rates, resume from checkpoint, fp32-resume of the AMP scaler, and a
    drop-path warm-up schedule.
    """
    utils.my_init_distributed_mode(args)
    print(args)
    # if args.distillation_type != 'none' and args.finetune and not args.eval:
    #     raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility (per-rank offset so ranks draw different data)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # (legacy RASampler/DistributedSampler + DataLoader construction was
    # commented out here; the `samplers` module below replaces it)
    if args.distributed:
        if args.cache_mode:
            # node-aware samplers so each node caches only its own shard
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   num_workers=args.num_workers, pin_memory=True)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, num_workers=args.num_workers,
                                 pin_memory=True)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )
    model_without_ddp = model

    # (an earlier in-place finetune-loading block was disabled here with the
    # note "there are bugs"; finetuning is handled further below instead)

    model.to(device)

    # EMA is disabled in this script (ModelEma creation commented out upstream)
    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        # keep a handle on the bare module for checkpointing / optimizer creation
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # linear LR scaling rule: lr *= total_batch_size / 512 (mutates args in place)
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # default criterion; overwritten by the branch below in every case
    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # wrap in a no-op distillation loss (type 'none', zero weights) so the
    # training loop can use a uniform criterion signature
    criterion = DistillationLoss(criterion, None, 'none', 0, 0)

    output_dir = Path(args.output_dir)

    # for finetune
    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            if not os.path.exists(args.finetune):
                # missing file is tolerated: training proceeds from scratch
                checkpoint = None
                print('NOTICE:' + args.finetune + ' does not exist!')
            else:
                checkpoint = torch.load(args.finetune, map_location='cpu')
        if checkpoint is not None:
            # accept both wrapped ({'model': state_dict}) and bare state dicts
            if 'model' in checkpoint:
                check_model = checkpoint['model']
            else:
                check_model = checkpoint
            # keys absent from the checkpoint identify freshly-initialized params
            missing_keys = model_without_ddp.load_state_dict(
                check_model, strict=False).missing_keys
            skip_keys = model_without_ddp.no_weight_decay()
            # create optimizer manually: new (missing) params train at full lr,
            # pretrained params at lr * fine_factor; no_weight_decay keys get wd=0
            param_dicts = [
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n in missing_keys and n not in skip_keys
                    ],
                    "lr": args.lr,
                    'weight_decay': args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n in missing_keys and n in skip_keys
                    ],
                    "lr": args.lr,
                    'weight_decay': 0,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n not in missing_keys and n not in skip_keys
                    ],
                    "lr": args.lr * args.fine_factor,
                    'weight_decay': args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n not in missing_keys and n in skip_keys
                    ],
                    "lr": args.lr * args.fine_factor,
                    'weight_decay': 0,
                },
            ]
            optimizer = torch.optim.AdamW(param_dicts,
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
            loss_scaler = NativeScaler()
            lr_scheduler, _ = create_scheduler(args, optimizer)
            # (restoring optimizer/lr_scheduler/scaler state from the finetune
            # checkpoint is intentionally disabled here)
            print('finetune from' + args.finetune)

    # for debug
    # lr_scheduler.step(10)
    # lr_scheduler.step(100)
    # lr_scheduler.step(200)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            if not os.path.exists(args.resume):
                # missing file is tolerated: training starts fresh
                checkpoint = None
                print('NOTICE:' + args.resume + ' does not exist!')
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')
        if checkpoint is not None:
            # accept both wrapped ({'model': state_dict}) and bare state dicts
            if 'model' in checkpoint:
                model_without_ddp.load_state_dict(checkpoint['model'])
            else:
                model_without_ddp.load_state_dict(checkpoint)
            # restore training state only when resuming a full training run
            if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1
                # if args.model_ema:
                #     utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])
            print('resume from' + args.resume)

    if args.eval:
        # evaluation-only run: report top-1 accuracy and exit
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0

    # drop-path warm-up: train without stochastic depth for the first epochs,
    # except for the small PVT variants which skip the warm-up entirely
    max_epoch_dp_warm_up = 100
    if 'pvt_tiny' in args.model or 'pvt_small' in args.model:
        max_epoch_dp_warm_up = 0
    if args.start_epoch < max_epoch_dp_warm_up:
        model_without_ddp.reset_drop_path(0.0)

    for epoch in range(args.start_epoch, args.epochs):
        # fp32_resume forces the first resumed epoch(s) to run without AMP,
        # then re-enables mixed precision
        if args.fp32_resume and epoch > args.start_epoch + 1:
            args.fp32_resume = False
        loss_scaler._scaler = torch.cuda.amp.GradScaler(
            enabled=not args.fp32_resume)

        if epoch == max_epoch_dp_warm_up:
            # warm-up over: restore the configured drop-path rate
            model_without_ddp.reset_drop_path(args.drop_path)

        if epoch < args.warmup_epochs:
            # NOTE(review): freezes the pretrained-parameter groups during warm-up;
            # indices 2/3 assume the 4-group finetune optimizer built above —
            # confirm this branch is unreachable without args.finetune.
            optimizer.param_groups[2]['lr'] = 0
            optimizer.param_groups[3]['lr'] = 0

        if args.distributed:
            # data_loader_train.sampler.set_epoch(epoch)
            sampler_train.set_epoch(epoch)

        train_stats = my_train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            # set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
            fp32=args.fp32_resume)

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # 'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters
        }

        # only rank 0 appends to the shared log file
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))