def main() -> None:
    """Entry point: set up distributed state, data, teacher/student models,
    optimizers/schedulers/AMP scalers, optionally resume from a checkpoint,
    then dispatch to finetune / evaluate / train_loop.

    Relies on module-level names defined elsewhere in this file:
    `parser`, `logger`, `DATASET_GETTERS`, `WideResNet`, `ModelEMA`,
    `create_loss_fn`, `get_cosine_schedule_with_warmup`,
    `model_load_state_dict`, `set_seed`, `finetune`, `evaluate`, `train_loop`.
    """
    args = parser.parse_args()
    # Best-so-far metrics, tracked on args so train/eval code can update them.
    args.best_top1 = 0.
    args.best_top5 = 0.

    # local_rank == -1 means single-process (non-distributed) training;
    # otherwise each process binds to the GPU matching its local rank.
    if args.local_rank != -1:
        args.gpu = args.local_rank
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
    else:
        args.gpu = 0
        args.world_size = 1

    args.device = torch.device('cuda', args.gpu)

    # Only rank -1/0 logs at INFO; other ranks are quieted to WARNING.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARNING)

    logger.warning(f"Process rank: {args.local_rank}, "
                   f"device: {args.device}, "
                   f"distributed training: {bool(args.local_rank != -1)}, "
                   f"16-bits training: {args.amp}")

    logger.info(dict(args._get_kwargs()))

    # TensorBoard writer only on the main process.
    if args.local_rank in [-1, 0]:
        args.writer = SummaryWriter(f"results/{args.name}")

    if args.seed is not None:
        set_seed(args)

    # Barrier pattern: non-main ranks wait here so that rank 0 downloads /
    # prepares the dataset exactly once; rank 0 releases them afterwards.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[
        args.dataset](args)

    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_loader = DataLoader(labeled_dataset,
                                sampler=train_sampler(labeled_dataset),
                                batch_size=args.batch_size,
                                num_workers=args.workers,
                                drop_last=True)
    # Unlabeled batches are args.mu times larger than labeled ones.
    unlabeled_loader = DataLoader(unlabeled_dataset,
                                  sampler=train_sampler(unlabeled_dataset),
                                  batch_size=args.batch_size * args.mu,
                                  num_workers=args.workers,
                                  drop_last=True)
    test_loader = DataLoader(test_dataset,
                             sampler=SequentialSampler(test_dataset),
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    # WideResNet-28-2 for CIFAR-10, WideResNet-28-8 for CIFAR-100.
    # NOTE(review): no `else` branch — any other args.dataset value leaves
    # depth/widen_factor unbound and raises NameError below; confirm whether
    # other datasets are supposed to be supported here.
    if args.dataset == "cifar10":
        depth, widen_factor = 28, 2
    elif args.dataset == 'cifar100':
        depth, widen_factor = 28, 8

    # Same barrier pattern as above, this time around model construction.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    # Conv dropout is disabled (dropout=0); only dense_dropout is configurable.
    teacher_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)
    student_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)

    if args.local_rank == 0:
        torch.distributed.barrier()

    teacher_model.to(args.device)
    student_model.to(args.device)
    # Optional EMA shadow of the student; args.ema is presumably the EMA
    # decay factor — TODO confirm against ModelEMA's signature.
    avg_student_model = None
    if args.ema > 0:
        avg_student_model = ModelEMA(student_model, args.ema)

    criterion = create_loss_fn(args)

    # Exclude batch-norm parameters (name contains 'bn') from weight decay.
    no_decay = ['bn']
    teacher_parameters = [{
        'params': [
            p for n, p in teacher_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in teacher_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    student_parameters = [{
        'params': [
            p for n, p in student_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in student_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]

    # weight_decay is intentionally NOT passed to SGD: it is applied through
    # the per-parameter groups above so that BN params are exempt.
    t_optimizer = optim.SGD(teacher_parameters,
                            lr=args.lr,
                            momentum=args.momentum,
                            nesterov=args.nesterov)
    s_optimizer = optim.SGD(student_parameters,
                            lr=args.lr,
                            momentum=args.momentum,
                            nesterov=args.nesterov)

    t_scheduler = get_cosine_schedule_with_warmup(t_optimizer,
                                                  args.warmup_steps,
                                                  args.total_steps)
    # The student scheduler additionally waits student_wait_steps before
    # its warmup begins.
    s_scheduler = get_cosine_schedule_with_warmup(s_optimizer,
                                                  args.warmup_steps,
                                                  args.total_steps,
                                                  args.student_wait_steps)

    # AMP grad scalers are no-ops when args.amp is False.
    t_scaler = amp.GradScaler(enabled=args.amp)
    s_scaler = amp.GradScaler(enabled=args.amp)

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(f"=> loading checkpoint '{args.resume}'")
            # Map storages to this process's GPU.
            loc = f'cuda:{args.gpu}'
            checkpoint = torch.load(args.resume, map_location=loc)
            # best_top1/best_top5 are stored as tensors; move them to CPU.
            args.best_top1 = checkpoint['best_top1'].to(torch.device('cpu'))
            args.best_top5 = checkpoint['best_top5'].to(torch.device('cpu'))
            if not (args.evaluate or args.finetune):
                # Full training resume: restore step, optimizer/scheduler/
                # scaler state and both models.
                args.start_step = checkpoint['step']
                t_optimizer.load_state_dict(checkpoint['teacher_optimizer'])
                s_optimizer.load_state_dict(checkpoint['student_optimizer'])
                t_scheduler.load_state_dict(checkpoint['teacher_scheduler'])
                s_scheduler.load_state_dict(checkpoint['student_scheduler'])
                t_scaler.load_state_dict(checkpoint['teacher_scaler'])
                s_scaler.load_state_dict(checkpoint['student_scaler'])
                model_load_state_dict(teacher_model,
                                      checkpoint['teacher_state_dict'])
                if avg_student_model is not None:
                    model_load_state_dict(avg_student_model,
                                          checkpoint['avg_state_dict'])
            else:
                # Evaluate/finetune: load only student weights, preferring
                # the EMA-averaged weights when the checkpoint has them.
                if checkpoint['avg_state_dict'] is not None:
                    model_load_state_dict(student_model,
                                          checkpoint['avg_state_dict'])
                else:
                    model_load_state_dict(student_model,
                                          checkpoint['student_state_dict'])
            logger.info(
                f"=> loaded checkpoint '{args.resume}' (step {checkpoint['step']})"
            )
        else:
            logger.info(f"=> no checkpoint found at '{args.resume}'")

    # Wrap both models for distributed training after (possibly) loading
    # weights, so state dicts above are loaded into the bare modules.
    if args.local_rank != -1:
        teacher_model = nn.parallel.DistributedDataParallel(
            teacher_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        student_model = nn.parallel.DistributedDataParallel(
            student_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    if args.finetune:
        # Free everything the finetune path does not need before running it.
        del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader
        del s_scaler, s_scheduler, s_optimizer
        finetune(args, labeled_loader, test_loader, student_model, criterion)
        return

    if args.evaluate:
        # Free everything the evaluation path does not need before running it.
        del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader, labeled_loader
        del s_scaler, s_scheduler, s_optimizer
        evaluate(args, test_loader, student_model, criterion)
        return

    teacher_model.zero_grad()
    student_model.zero_grad()
    train_loop(args, labeled_loader, unlabeled_loader, test_loader,
               teacher_model, student_model, avg_student_model, criterion,
               t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler,
               s_scaler)
    return
num_classes=n_classes) elif args.model == 'resnext': net = CifarResNeXt(cardinality=8, depth=29, base_width=64, widen_factor=4, nlabels=n_classes) else: raise Exception('Invalid model name') # create optimizer optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=args.momentum, nesterov=args.nesterov, weight_decay=args.weight_decay) net.to('cuda') if torch.cuda.device_count() > 1: net = torch.nn.DataParallel(net) cudnn.benchmark = True criterion = nn.CrossEntropyLoss().cuda() # trainer if args.adversarial: if args.regu == 'no': trainer = AdversarialTrainer(net, criterion, optimizer, args) elif args.regu == 'random-svd': trainer = AdversarialOrthReguTrainer(net, criterion, optimizer, args) else: raise Exception('Invalid setting for adversarial training') else: if args.regu == 'no':