def main_worker(gpu, ngpus_per_node, args):
    """Run fine-tuning (or evaluation only) of a Bayesian network on one GPU.

    The worker builds the pretrained ``args.arch`` model, converts it with
    ``to_bayesian``, optionally wraps it in ``DistributedDataParallel``,
    optionally resumes from a checkpoint, and then either evaluates once
    (``args.evaluate``) or trains for ``args.epochs`` epochs, saving a
    checkpoint after every epoch and evaluating once at the end.

    Args:
        gpu: CUDA device index used by this worker; stored in ``args.gpu``.
        ngpus_per_node: Number of GPUs on this node. Used to derive the
            global rank and to split ``args.batch_size`` in distributed mode.
        args: Configuration namespace. It is mutated in place: ``gpu`` is
            always set, and ``rank``, ``batch_size``, ``resume`` and
            ``start_epoch`` may be overwritten. ``args.start_epoch`` must
            already be set when no checkpoint is loaded.

    Side effects:
        Writes a log file under ``args.save_path``, updates the module-level
        ``best_acc``, and (on GPU 0 only) writes ``checkpoint.pth.tar``.
    """
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    log_path = os.path.join(
        args.save_path,
        'log_seed{}{}.txt'.format(args.manualSeed,
                                  '_eval' if args.evaluate else ''))

    # The context manager guarantees the log file is closed on every exit
    # path, including the early return taken in evaluate-only mode and any
    # exception raised during setup, training or evaluation.
    with open(log_path, 'w') as log_file:
        # print_log / train / evaluate receive the (file, gpu) pair.
        log = (log_file, args.gpu)

        net = models.__dict__[args.arch](pretrained=True)
        disable_dropout(net)
        net = to_bayesian(net, args.psi_init_range)
        net.apply(unfreeze)

        print_log("Python version : {}".format(
            sys.version.replace('\n', ' ')), log)
        print_log("PyTorch version : {}".format(torch.__version__), log)
        print_log("CuDNN version : {}".format(
            torch.backends.cudnn.version()), log)
        print_log(
            "Number of parameters: {}".format(
                sum(p.numel() for p in net.parameters())), log)
        print_log(str(args), log)

        if args.distributed:
            if args.multiprocessing_distributed:
                # Convert the per-node rank into a global rank.
                args.rank = args.rank * ngpus_per_node + gpu
            dist.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url + ":" + args.dist_port,
                world_size=args.world_size,
                rank=args.rank)
            torch.cuda.set_device(args.gpu)
            net.cuda(args.gpu)
            # Each process handles an equal share of the configured batch.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[args.gpu])
        else:
            torch.cuda.set_device(args.gpu)
            net = net.cuda(args.gpu)

        criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

        # Parameters whose name contains 'psi' go to the PsiSGD optimizer;
        # all remaining parameters go to the plain SGD optimizer.
        mus, psis = [], []
        for name, param in net.named_parameters():
            if 'psi' in name:
                psis.append(param)
            else:
                mus.append(param)
        mu_optimizer = SGD(mus, args.learning_rate, args.momentum,
                           weight_decay=args.decay,
                           nesterov=(args.momentum > 0.0))
        psi_optimizer = PsiSGD(psis, args.learning_rate, args.momentum,
                               weight_decay=args.decay,
                               nesterov=(args.momentum > 0.0))

        recorder = RecorderMeter(args.epochs)
        if args.resume:
            if args.resume == 'auto':
                args.resume = os.path.join(args.save_path,
                                           'checkpoint.pth.tar')
            if os.path.isfile(args.resume):
                print_log("=> loading checkpoint '{}'".format(args.resume),
                          log)
                # NOTE(review): the checkpoint stores a pickled recorder
                # object next to the tensors, so only trusted checkpoint
                # files should be loaded here.
                checkpoint = torch.load(
                    args.resume,
                    map_location='cuda:{}'.format(args.gpu))
                recorder = checkpoint['recorder']
                recorder.refresh(args.epochs)
                args.start_epoch = checkpoint['epoch']
                if args.distributed:
                    state_dict = checkpoint['state_dict']
                else:
                    # Drop the 'module.' key prefix so the weights fit the
                    # unwrapped (non-DistributedDataParallel) model.
                    state_dict = {
                        k.replace('module.', ''): v
                        for k, v in checkpoint['state_dict'].items()
                    }
                net.load_state_dict(state_dict)
                mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
                psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
                best_acc = recorder.max_accuracy(False)
                print_log(
                    "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                        args.resume, best_acc, checkpoint['epoch']), log)
            else:
                print_log(
                    "=> no checkpoint found at '{}'".format(args.resume), log)
        else:
            print_log("=> do not use any checkpoint for the model", log)

        cudnn.benchmark = True

        train_loader, ood_train_loader, test_loader, adv_loader, \
            fake_loader, adv_loader2 = load_dataset_ft(args)
        psi_optimizer.num_data = len(train_loader.dataset)

        if args.evaluate:
            evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                     criterion, args, log, 20, 100)
            return

        start_time = time.time()
        epoch_time = AverageMeter()
        train_los = -1

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
                ood_train_loader.sampler.set_epoch(epoch)
            cur_lr, cur_slr = adjust_learning_rate(
                mu_optimizer, psi_optimizer, epoch, args)

            # Remaining-time estimate: average epoch duration so far times
            # the number of epochs still to run.
            need_hour, need_mins, need_secs = convert_secs2time(
                epoch_time.avg * (args.epochs - epoch))
            need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
                need_hour, need_mins, need_secs)

            epoch_header = (
                '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} '
                '[learning_rate={:6.4f} {:6.4f}]'.format(
                    time_string(), epoch, args.epochs, need_time,
                    cur_lr, cur_slr))
            best_summary = ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False))
            print_log(epoch_header + best_summary, log)

            train_acc, train_los = train(
                train_loader, ood_train_loader, net, criterion,
                mu_optimizer, psi_optimizer, epoch, args, log)
            # No per-epoch validation is run; zeros are recorded instead.
            val_acc, val_los = 0, 0
            recorder.update(epoch, train_los, train_acc, val_acc, val_los)

            if val_acc > best_acc:
                best_acc = val_acc

            # Only the worker on GPU 0 writes the checkpoint file.
            if args.gpu == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': net.state_dict(),
                        'recorder': recorder,
                        'mu_optimizer': mu_optimizer.state_dict(),
                        'psi_optimizer': psi_optimizer.state_dict(),
                    },
                    # Presumably the "is best" flag; always False here --
                    # TODO confirm against save_checkpoint.
                    False,
                    args.save_path,
                    'checkpoint.pth.tar')

            epoch_time.update(time.time() - start_time)
            start_time = time.time()
            recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                 criterion, args, log, 20, 100)
from models.wrn import wrn if __name__ == '__main__': parser = argparse.ArgumentParser() args = parser.parse_args() args.epochs = 1 args.dataset = 'cifar10' args.data_path = '/data/LargeData/Regular/cifar' args.cutout = True args.distributed = False args.batch_size = 32 args.workers = 4 train_loader, test_loader = load_dataset(args) net = wrn(pretrained=True, depth=28, width=10).cuda() disable_dropout(net) eval_loss, eval_acc = Bayes_ensemble(test_loader, net, num_mc_samples=1) print('Results of deterministic pre-training, ' 'eval loss {}, eval acc {}'.format(eval_loss, eval_acc)) bayesian_net = to_bayesian(net) bayesian_net.apply(unfreeze) mus, psis = [], [] for name, param in bayesian_net.named_parameters(): if 'psi' in name: psis.append(param) else: mus.append(param) mu_optimizer = SGD(mus, lr=0.0008, momentum=0.9,