def main():
    global args, config, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    config = EasyDict(config['common'])
    config.save_path = os.path.dirname(args.config)

    rank, world_size = dist_init()

    # create model
    bn_group_size = config.model.kwargs.bn_group_size
    bn_var_mode = config.model.kwargs.get('bn_var_mode', 'L2')
    if bn_group_size == 1:
        bn_group = None
    else:
        assert world_size % bn_group_size == 0
        bn_group = simple_group_split(world_size, rank,
                                      world_size // bn_group_size)

    config.model.kwargs.bn_group = bn_group
    config.model.kwargs.bn_var_mode = (link.syncbnVarMode_t.L1
                                       if bn_var_mode == 'L1'
                                       else link.syncbnVarMode_t.L2)
    model = model_entry(config.model)
    if rank == 0:
        print(model)

    model.cuda()

    if config.optimizer.type == 'FP16SGD' or config.optimizer.type == 'FusedFP16SGD':
        args.fp16 = True
    else:
        args.fp16 = False

    if args.fp16:
        # If you have modules that must keep fp32 parameters and need fp32 inputs,
        # try link.fp16.register_float_module(your_module).
        # If you only need fp32 parameters, set cast_args=False when calling that
        # function, then call link.fp16.init() before calling model.half().
        if config.optimizer.get('fp16_normal_bn', False):
            print('using normal bn for fp16')
            link.fp16.register_float_module(link.nn.SyncBatchNorm2d, cast_args=False)
            link.fp16.register_float_module(torch.nn.BatchNorm2d, cast_args=False)
            link.fp16.init()
        model.half()

    model = DistModule(model, args.sync)

    # create optimizer
    opt_config = config.optimizer
    opt_config.kwargs.lr = config.lr_scheduler.base_lr
    if config.get('no_wd', False):
        param_group, type2num = param_group_no_wd(model)
        opt_config.kwargs.params = param_group
    else:
        opt_config.kwargs.params = model.parameters()

    optimizer = optim_entry(opt_config)

    # optionally resume from a checkpoint
    last_iter = -1
    best_prec1 = 0
    if args.load_path:
        if args.recover:
            best_prec1, last_iter = load_state(args.load_path, model, optimizer=optimizer)
        else:
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # augmentation
    aug = [
        transforms.RandomResizedCrop(config.augmentation.input_size),
        transforms.RandomHorizontalFlip()
    ]

    for k in config.augmentation.keys():
        assert k in [
            'input_size', 'test_resize', 'rotation', 'colorjitter', 'colorold'
        ]

    rotation = config.augmentation.get('rotation', 0)
    colorjitter = config.augmentation.get('colorjitter', None)
    colorold = config.augmentation.get('colorold', False)

    if rotation > 0:
        aug.append(transforms.RandomRotation(rotation))

    if colorjitter is not None:
        aug.append(transforms.ColorJitter(*colorjitter))

    aug.append(transforms.ToTensor())

    if colorold:
        aug.append(ColorAugmentation())

    aug.append(normalize)

    # train
    train_dataset = McDataset(config.train_root,
                              config.train_source,
                              transforms.Compose(aug),
                              fake=args.fake)

    # val
    val_dataset = McDataset(
        config.val_root,
        config.val_source,
        transforms.Compose([
            transforms.Resize(config.augmentation.test_resize),
            transforms.CenterCrop(config.augmentation.input_size),
            transforms.ToTensor(),
            normalize,
        ]),
        args.fake)

    train_sampler = DistributedGivenIterationSampler(
        train_dataset,
        config.lr_scheduler.max_iter,
        config.batch_size,
        last_iter=last_iter)
    val_sampler = DistributedSampler(val_dataset, round_up=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.workers,
                              pin_memory=True,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True,
                            sampler=val_sampler)

    config.lr_scheduler['optimizer'] = optimizer.optimizer if isinstance(
        optimizer, FP16SGD) else optimizer
    config.lr_scheduler['last_iter'] = last_iter
    lr_scheduler = get_scheduler(config.lr_scheduler)

    if rank == 0:
        tb_logger = SummaryWriter(config.save_path + '/events')
        logger = create_logger('global_logger', config.save_path + '/log.txt')
        logger.info('args: {}'.format(pprint.pformat(args)))
        logger.info('config: {}'.format(pprint.pformat(config)))
    else:
        tb_logger = None

    if args.evaluate:
        if args.fusion_list is not None:
            validate(val_loader, model,
                     fusion_list=args.fusion_list,
                     fuse_prob=args.fuse_prob)
        else:
            validate(val_loader, model)
        link.finalize()
        return

    train(train_loader, val_loader, model, optimizer, lr_scheduler,
          last_iter + 1, tb_logger)

    link.finalize()
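
The main() above relies on a module-level parser and an entry-point guard defined elsewhere in the file; it reads args.config, args.load_path, args.recover, args.evaluate, args.fake, args.sync, args.fusion_list and args.fuse_prob. Below is a minimal sketch of what that boilerplate could look like: only the flag names are taken from the code above, while the types, defaults, and help strings are assumptions.

import argparse

# Hypothetical parser sketch: flag names mirror the attributes read by main();
# defaults and help texts are illustrative assumptions, not the repo's actual CLI.
parser = argparse.ArgumentParser(description='distributed ImageNet training (sketch)')
parser.add_argument('--config', required=True, help='path to the yaml config')
parser.add_argument('--load-path', default='', type=str, help='checkpoint to load')
parser.add_argument('--recover', action='store_true', help='also restore optimizer state and last iteration')
parser.add_argument('--evaluate', action='store_true', help='run validation only')
parser.add_argument('--fake', action='store_true', help='use fake data for benchmarking')
parser.add_argument('--sync', action='store_true', help='synchronous gradient reduction in DistModule')
parser.add_argument('--fusion-list', nargs='+', default=None, help='checkpoints to fuse at test time')
parser.add_argument('--fuse-prob', action='store_true', help='fuse probabilities instead of logits')

if __name__ == '__main__':
    main()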
def main():
    global cfg, rank, world_size
    cfg = Config.fromfile(args.config)

    # Set seed
    np.random.seed(cfg.seed)
    cudnn.benchmark = True
    torch.manual_seed(cfg.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(cfg.seed)

    # Model
    print('==> Building model..')
    arch_code = eval('architecture_code.{}'.format(cfg.model))
    net = models.model_entry(cfg, arch_code)
    rank = 0        # for non-distributed
    world_size = 1  # for non-distributed
    if args.distributed:
        print('==> Initializing distributed training..')
        # Only slurm is supported for now; to personalize your launcher, please refer to
        # https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py
        init_dist(launcher='slurm', backend='nccl')
        rank, world_size = get_dist_info()
    net = net.cuda()

    cfg.netpara = sum(p.numel() for p in net.parameters()) / 1e6

    start_epoch = 0
    best_acc = 0

    # Load checkpoint.
    if cfg.get('resume_path', False):
        print('==> Resuming from {}checkpoint {}..'.format(
            ('original ' if cfg.resume_path.origin_ckpt else ''),
            cfg.resume_path.path))
        if cfg.resume_path.origin_ckpt:
            utils.load_state(cfg.resume_path.path, net, rank=rank)
        else:
            if args.distributed:
                net = torch.nn.parallel.DistributedDataParallel(
                    net,
                    device_ids=[torch.cuda.current_device()],
                    output_device=torch.cuda.current_device())
            utils.load_state(cfg.resume_path.path, net, rank=rank)

    # Data
    print('==> Preparing data..')
    if args.eval_only:
        testloader = dataset_entry(cfg, args.distributed, args.eval_only)
    else:
        trainloader, testloader, train_sampler, test_sampler = dataset_entry(
            cfg, args.distributed, args.eval_only)
        print(trainloader, testloader, train_sampler, test_sampler)

    criterion = nn.CrossEntropyLoss()

    # Use a 7-step PGD attack during training; evaluation keeps the config's setting.
    if not args.eval_only:
        cfg.attack_param.num_steps = 7
    net_adv = AttackPGD(net, cfg.attack_param)

    if not args.eval_only:
        # Train params
        print('==> Setting train parameters..')
        train_param = cfg.train_param
        epochs = train_param.epochs
        init_lr = train_param.learning_rate
        if train_param.get('warm_up_param', False):
            warm_up_param = train_param.warm_up_param
            init_lr = warm_up_param.warm_up_base_lr
            epochs += warm_up_param.warm_up_epochs
        if train_param.get('no_wd', False):
            param_group, type2num, _, _ = utils.param_group_no_wd(net)
            cfg.param_group_no_wd = type2num
            optimizer = torch.optim.SGD(param_group,
                                        lr=init_lr,
                                        momentum=train_param.momentum,
                                        weight_decay=train_param.weight_decay)
        else:
            optimizer = torch.optim.SGD(net.parameters(),
                                        lr=init_lr,
                                        momentum=train_param.momentum,
                                        weight_decay=train_param.weight_decay)
        scheduler = lr_scheduler.CosineLRScheduler(
            optimizer, epochs, train_param.learning_rate_min, init_lr,
            train_param.learning_rate,
            (warm_up_param.warm_up_epochs
             if train_param.get('warm_up_param', False) else 0))

    # Log
    print('==> Writing log..')
    if rank == 0:
        cfg.save = '{}/{}-{}-{}'.format(cfg.save_path, cfg.model, cfg.dataset,
                                        time.strftime("%Y%m%d-%H%M%S"))
        utils.create_exp_dir(cfg.save)
        logger = utils.create_logger('global_logger', cfg.save + '/log.txt')
        logger.info('config: {}'.format(pprint.pformat(cfg)))

    # Evaluation only
    if args.eval_only:
        assert cfg.get('resume_path', False), \
            'Should set the resume path for the eval_only mode'
        print('==> Testing on Clean Data..')
        test(net, testloader, criterion)
        print('==> Testing on Adversarial Data..')
        test(net_adv, testloader, criterion, adv=True)
        return

    # Training process
    for epoch in range(start_epoch, epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
            test_sampler.set_epoch(epoch)
        scheduler.step()
        if rank == 0:
            logger.info('Epoch %d learning rate %e', epoch,
                        scheduler.get_lr()[0])

        # Train for one epoch
        train(net_adv, trainloader, criterion, optimizer)

        # Validate for one epoch
        valid_acc = test(net_adv, testloader, criterion, adv=True)

        if rank == 0:
            logger.info('Validation Accuracy: {}'.format(valid_acc))

            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)

            # Save checkpoint (cfg.save is only created on rank 0)
            print('==> Saving')
            state = {
                'epoch': epoch,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'state_dict': net.state_dict(),
                'scheduler': scheduler
            }
            utils.save_checkpoint(state, is_best, os.path.join(cfg.save))
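
AttackPGD above wraps the classifier so that a forward pass first crafts a projected-gradient-descent (PGD) adversarial example and then classifies it; the repo's actual implementation is defined elsewhere. The sketch below only illustrates the standard PGD pattern under assumed attack_param fields (epsilon and step_size, plus a random start inside the L-inf ball); only num_steps is confirmed by the code above, and the forward signature is likewise an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F


class PGDAttackSketch(nn.Module):
    """Illustrative PGD wrapper, not the repo's AttackPGD."""

    def __init__(self, net, attack_param):
        super().__init__()
        self.net = net
        self.num_steps = attack_param.num_steps  # confirmed field (set to 7 for training above)
        self.step_size = attack_param.step_size  # assumed field
        self.epsilon = attack_param.epsilon      # assumed field

    def forward(self, inputs, targets):
        # Random start inside the L-inf ball, then iterated signed-gradient ascent
        # on the cross-entropy loss, projected back into the ball after each step.
        x = inputs.detach() + torch.empty_like(inputs).uniform_(-self.epsilon, self.epsilon)
        x = torch.clamp(x, 0.0, 1.0)
        for _ in range(self.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                loss = F.cross_entropy(self.net(x), targets)
            grad = torch.autograd.grad(loss, x)[0]
            x = x.detach() + self.step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
            x = torch.clamp(x, 0.0, 1.0)
        return self.net(x), x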