            break

        model_name = os.path.join(
            args.load,
            args.dataset + '_' + args.model + '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break

    if start_epoch == 0:
        assert False, "could not resume"

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    net.cuda()
    torch.cuda.manual_seed(1)

cudnn.benchmark = True  # fire on all cylinders

net.eval()

concat = lambda x: np.concatenate(x, axis=0)
to_np = lambda x: x.data.to('cpu').numpy()


def evaluate(loader):
    confidence = []
    correct = []
    num_correct = 0

    with torch.no_grad():
def main():
    train_transform = trn.Compose([
        trn.RandomHorizontalFlip(),
        trn.RandomCrop(32, padding=4),
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    test_transform = trn.Compose([
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if args.dataset == 'cifar10':
        print("Using CIFAR 10")
        train_data_in = dset.CIFAR10('/data/sauravkadavath/cifar10-dataset', train=True, transform=train_transform)
        test_data = dset.CIFAR10('/data/sauravkadavath/cifar10-dataset', train=False, transform=test_transform)
        num_classes = 10
    else:
        print("Using CIFAR100")
        train_data_in = dset.CIFAR100('/data/sauravkadavath/cifar10-dataset', train=True, transform=train_transform)
        test_data = dset.CIFAR100('/data/sauravkadavath/cifar10-dataset', train=False, transform=test_transform)
        num_classes = 100

    train_loader_in = torch.utils.data.DataLoader(
        train_data_in,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.prefetch, pin_memory=True)

    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    net.cuda()

    optimizer = torch.optim.SGD(
        net.parameters(), state['learning_rate'],
        momentum=state['momentum'], weight_decay=state['decay'], nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader_in),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    # Make save directory
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    if not os.path.isdir(args.save):
        raise Exception('%s is not a dir' % args.save)

    print('Beginning Training\n')

    with open(os.path.join(args.save, "training_log.csv"), 'w') as f:
        # Column header for the epoch-level log; matches the console output below
        f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')

    # Main loop
    for epoch in range(0, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        train(net, state, train_loader_in, optimizer, lr_scheduler)
        test(net, state, test_loader)

        # Save model
        torch.save(
            net.state_dict(),
            os.path.join(
                args.save,
                '{0}_{1}_layers_{2}_widenfactor_{3}_transform_epoch_{4}.pt'.format(
                    args.dataset, args.model, str(args.layers), str(args.widen_factor), str(epoch))))

        # Let us not waste space and delete the previous model
        prev_path = os.path.join(
            args.save,
            '{0}_{1}_layers_{2}_widenfactor_{3}_transform_epoch_{4}.pt'.format(
                args.dataset, args.model, str(args.layers), str(args.widen_factor), str(epoch - 1)))
        if os.path.exists(prev_path):
            os.remove(prev_path)

        # Show results
        print(
            'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'
            .format(
                (epoch + 1),
                int(time.time() - begin_epoch),
                state['train_loss'],
                state['test_loss'],
                100 - 100. * state['test_accuracy']))
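
# The train() and test() helpers called in main() are defined elsewhere in the
# repo. As a reference point only, a minimal sketch of what they are assumed to
# do -- one optimizer/scheduler step per batch (the cosine schedule above sizes
# total_steps as args.epochs * len(train_loader_in), i.e. per-batch stepping),
# plus loss/accuracy bookkeeping in `state` -- might look like the following.
# The names (`train_sketch`, `test_sketch`) and exact bookkeeping are
# assumptions, not the repo's actual implementation.

import torch.nn.functional as F  # assumed available for the sketch below


def train_sketch(net, state, train_loader, optimizer, lr_scheduler):
    net.train()
    loss_avg = 0.0
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        logits = net(data)
        loss = F.cross_entropy(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # schedule is stepped per batch, matching total_steps above
        loss_avg = loss_avg * 0.9 + float(loss) * 0.1  # exponential moving average
    state['train_loss'] = loss_avg


def test_sketch(net, state, test_loader):
    net.eval()
    loss_sum, num_correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            logits = net(data)
            loss_sum += float(F.cross_entropy(logits, target, reduction='sum'))
            num_correct += int(logits.argmax(dim=1).eq(target).sum())
    state['test_loss'] = loss_sum / len(test_loader.dataset)
    state['test_accuracy'] = num_correct / len(test_loader.dataset)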