def main():
    """Train a WideResNet on CIFAR-10 with an auxiliary 4-way rotation head.

    Reads all hyperparameters from ``arg_parser()`` (seed, method, batch size,
    network width/depth, SGD settings, epochs) and saves the final weights to
    ``./trained_model_<method>.pth``.
    """
    # arg parser
    args = arg_parser()

    # set seed for reproducibility
    set_seed(args.seed)

    # dataset: wrap CIFAR-10 so each sample also carries a rotation label
    id_traindata = datasets.CIFAR10('./data/', train=True, download=True)
    id_testdata = datasets.CIFAR10('./data/', train=False, download=True)
    id_traindata = RotDataset(id_traindata, train_mode=True)
    id_testdata = RotDataset(id_testdata, train_mode=False)

    # data loader — both supported methods use the same loader; anything else
    # is a configuration error, so fail fast
    if args.method in ('rot', 'msp'):
        train_loader = dataloader(id_traindata, batch_size=args.batch_size,
                                  num_workers=args.num_workers, pin_memory=True)
    else:
        raise ValueError(args.method)
    test_loader = dataloader(id_testdata, batch_size=args.batch_size,
                             num_workers=args.num_workers, pin_memory=True)

    # model: WideResNet backbone plus a 4-way rotation-prediction head
    # (assumes the penultimate feature width is 128 — TODO confirm for the
    # configured widen_factor)
    num_classes = 10
    model = WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate)
    model.rot_head = nn.Linear(128, 4)
    model = model.cuda()

    # optimizer: SGD with Nesterov momentum; no LR schedule is used here
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.decay, nesterov=True)

    # training loop: train one epoch, then evaluate on the ID test split
    for epoch in range(1, args.epochs + 1):
        train_loss = train(args, epoch, model, train_loader,
                           optimizer, lr_scheduler=None)
        test_loss, test_acc = test(args, model, test_loader)
        print('epoch:{}, train_loss:{}, test_loss:{}, test_acc:{}'.format(
            epoch, round(train_loss.item(), 4), round(test_loss.item(), 4),
            round(test_acc, 4)))

    # NOTE(review): the evaluation script loads from
    # './models/trained_model_{}.pth' but this saves to './' — confirm the
    # intended location and make the two paths agree.
    torch.save(model.state_dict(),
               './trained_model_{}.pth'.format(args.method))
def main():
    """Load a trained model and prepare ID/OOD loaders for OOD evaluation.

    Builds CIFAR-10 as the in-distribution test set and either CIFAR-100 or
    SVHN as the out-of-distribution test set (selected by
    ``args.ood_dataset``), then restores the checkpoint saved by the training
    script. The scoring/evaluation step itself is not implemented yet.
    """
    # arg parser
    args = arg_parser()

    # set seed for reproducibility
    set_seed(args.seed)

    # in-distribution test set (CIFAR-10), wrapped with rotation labels
    id_testdata = datasets.CIFAR10('./data/', train=False, download=True)
    id_testdata = RotDataset(id_testdata, train_mode=False)

    # out-of-distribution test set — unknown names are a configuration
    # error, so fail fast
    if args.ood_dataset == 'cifar100':
        ood_testdata = datasets.CIFAR100('./data/', train=False, download=True)
    elif args.ood_dataset == 'svhn':
        ood_testdata = datasets.SVHN('./data/', split='test', download=True)
    else:
        raise ValueError(args.ood_dataset)
    ood_testdata = RotDataset(ood_testdata, train_mode=False)

    # data loaders
    id_test_loader = dataloader(id_testdata, batch_size=args.batch_size,
                                num_workers=args.num_workers, pin_memory=True)
    ood_test_loader = dataloader(ood_testdata, batch_size=args.batch_size,
                                 num_workers=args.num_workers, pin_memory=True)

    # load model: same architecture as training, including the rotation head,
    # so the checkpoint's state_dict keys line up
    num_classes = 10
    model = WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate)
    model.rot_head = nn.Linear(128, 4)
    model = model.cuda()
    model.load_state_dict(
        torch.load('./models/trained_model_{}.pth'.format(args.method)))

    # TODO: compute OOD scores over id_test_loader / ood_test_loader and
    # report detection metrics (e.g. AUROC). The original source ended with a
    # bare `TODO:` token, which is a syntax error — replaced with this comment.
def main():
    """Train a WideResNet on CIFAR-10 with rotation + x/y-translation heads.

    Reads configuration from the module-level ``args`` and ``state`` objects.
    Optionally restricts train/test data to a single class
    (``args.in_class``), trains with SGD + cosine-annealed LR, and checkpoints
    after every epoch (deleting the previous epoch's checkpoint).
    """
    print("Using CIFAR 10")
    train_data_in = dset.CIFAR10('/home/jiuhai.chen/data', train=True,
                                 download=True)
    test_data = dset.CIFAR10('/home/jiuhai.chen/data', train=False,
                             download=True)
    num_classes = 10

    # 0 airplane, 1 automobile, 2 bird, 3 cat, 4 deer, 5 dog, 6 frog,
    # 7 horse, 8 ship, 9 truck
    # `is not None` (not truthiness) so that class 0 is handled correctly.
    if args.in_class is not None:
        print("Removing all but class {0} from train dataset and test dataset".
              format(args.in_class))
        # BUGFIX: the original filtered only `.data`, leaving `.targets` at
        # full length, so images and labels became misaligned. Filter both
        # with the same boolean mask.
        train_targets = np.asarray(train_data_in.targets)
        train_mask = train_targets == args.in_class
        train_data_in.data = train_data_in.data[train_mask]
        train_data_in.targets = train_targets[train_mask].tolist()

        test_targets = np.asarray(test_data.targets)
        test_mask = test_targets == args.in_class
        test_data.data = test_data.data[test_mask]
        test_data.targets = test_targets[test_mask].tolist()
    else:
        print("Keeping all classes in both train/test datasets")

    # wrap datasets so each sample carries self-supervised transform labels
    train_data_in = PerturbDataset(train_data_in, train_mode=True)
    test_data = PerturbDataset(test_data, train_mode=False)

    train_loader_in = torch.utils.data.DataLoader(
        train_data_in, batch_size=args.batch_size, shuffle=True,
        num_workers=args.prefetch, pin_memory=False)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False,
        num_workers=args.prefetch, pin_memory=False)

    # Create model: backbone plus three auxiliary self-supervised heads
    # (assumes a 128-wide penultimate feature — TODO confirm for the
    # configured widen_factor)
    net = WideResNet(args.layers, num_classes, args.widen_factor,
                     dropRate=args.droprate)
    net.x_trans_head = nn.Linear(128, 3)
    net.y_trans_head = nn.Linear(128, 3)
    net.rot_head = nn.Linear(128, 4)

    # Get GPUs ready
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    if args.ngpu > 0:
        net.cuda()
        torch.cuda.manual_seed(1)
    cudnn.benchmark = True  # fire on all cylinders

    # Set up optimization stuffs
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        # Standard cosine decay from lr_max down to lr_min over total_steps.
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    # LambdaLR multiplies the base LR by this factor each step, so lr_max is
    # 1 and lr_min is expressed relative to the configured learning rate.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader_in),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    print('Beginning Training\n')

    # Main loop
    for epoch in range(0, args.epochs):
        state['epoch'] = epoch
        begin_epoch = time.time()

        train(net, state, train_loader_in, optimizer, lr_scheduler)
        test(net, state, test_loader)

        # Save model
        torch.save(
            net.state_dict(),
            os.path.join(
                args.save,
                'layers_{0}_widenfactor_{1}_inclass_{2}_transform_trflossweight_{3}_epoch_{4}.pt'
                .format(
                    str(args.layers), str(args.widen_factor),
                    str(args.in_class),
                    str(args.rot_loss_weight) + "_" +
                    str(args.transl_loss_weight), str(epoch))))

        # Let us not waste space and delete the previous model
        prev_path = os.path.join(
            args.save,
            'layers_{0}_widenfactor_{1}_inclass_{2}_transform_trflossweight_{3}_epoch_{4}.pt'
            .format(
                str(args.layers), str(args.widen_factor), str(args.in_class),
                str(args.rot_loss_weight) + "_" +
                str(args.transl_loss_weight), str(epoch - 1)))
        if os.path.exists(prev_path):
            os.remove(prev_path)

        # Show results
        print(
            'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Accuracy {4:.3f}%'
            .format((epoch + 1), int(time.time() - begin_epoch),
                    state['train_loss'], state['test_loss'],
                    state['test_accuracy'] * 100))