# Fragment: the statement opening this snippet is truncated in the source; the
# WideResNet constructor head is reconstructed from the matching call used later.
model = WideResNet(depth=depth, num_classes=num_classes,
                   widen_factor=widen_factor, dropRate=drop_rate)
model = model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9,
                            nesterov=True, weight_decay=0.0005)
scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=gamma)

# Resume from a checkpoint if one exists; otherwise start from scratch.
try:
    checkpoint_fpath = 'cifar-10/cifar10_wideresnet79.pt'
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Rebuild the scheduler so its internal epoch counter matches the checkpoint.
    scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=0.2,
                            last_epoch=checkpoint['epoch'])
    begin = checkpoint['epoch']
    # print('test_acc :', checkpoint['test_acc'], 'train_acc :', checkpoint['train_acc'])
    # print('last_lr :', checkpoint['scheduler']['_last_lr'])
except FileNotFoundError:
    # print('starting over..')
    begin = -1

best_acc = 0
for epoch in range(epochs):
    if epoch <= begin:
        continue  # skip epochs already covered by the checkpoint (loop body truncated in source)
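
# The resume logic above expects a checkpoint dict with 'state_dict',
# 'optimizer', and 'epoch' entries. A minimal sketch of the matching save
# call, assuming those three keys are all that is required; save_checkpoint
# is a hypothetical helper, not part of the original code:
def save_checkpoint(model, optimizer, epoch, fpath):
    torch.save({'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()},
               fpath)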
def main(args):
    np.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Construct the model, optionally restoring options and weights from a checkpoint.
    if args.resume != "":
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # Restore the options saved with the checkpoint, but keep the
            # freshly parsed --test_only and --resume flags.
            test_only = args.test_only
            resume = args.resume
            args = checkpoint["opt"]
            args.test_only = test_only
            args.resume = resume
        else:
            checkpoint = None
            print("=> No checkpoint found at '{}'".format(args.resume))

    model = WideResNet(args.depth, args.widen_factor, args.dropout_rate, args.num_classes)
    if torch.cuda.is_available():
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=args.gpu)

    if args.resume != "" and checkpoint is not None:
        model.load_state_dict(checkpoint["model"])
        args.start_epoch = checkpoint["epoch"] + 1
        print("=> Loaded successfully '{}' (epoch {})".format(
            args.resume, checkpoint["epoch"]))
        del checkpoint
        torch.cuda.empty_cache()
    else:
        model.apply(conv_init)

    # Loading Dataset
    if args.augment == "meanstd":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])
    elif args.augment == "zca":
        # TODO: ZCA whitening
        pass
    else:
        raise NotImplementedError

    print("| Preparing CIFAR-10 dataset...")
    sys.stdout.write("| ")
    trainset = CIFAR10(root="./data", train=True, download=True, transform=transform_train)
    testset = CIFAR10(root="./data", train=False, download=False, transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # Test only
    if args.test_only:
        if args.resume != "":
            test(args, test_loader, model)
            sys.exit(0)
        else:
            print("=> Test-only mode needs to resume from a checkpoint")
            raise RuntimeError

    train(args, train_loader, test_loader, model)
    test(args, test_loader, model)
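
# The "zca" branch above is an unimplemented TODO. A minimal sketch of how a
# ZCA whitening matrix could be computed from the flattened, zero-centered
# training images; compute_zca_matrix and eps are assumptions for illustration,
# not the repo's actual implementation:
def compute_zca_matrix(data, eps=1e-5):
    # data: (N, D) float tensor, already zero-centered per feature
    cov = data.t() @ data / data.shape[0]   # (D, D) covariance estimate
    U, S, _ = torch.linalg.svd(cov)         # eigendecomposition of the covariance
    return U @ torch.diag(1.0 / torch.sqrt(S + eps)) @ U.t()
# Whitened images would then be (x.view(N, -1) @ W_zca.t()).view_as(x).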
def experiment():
    parser = argparse.ArgumentParser(description='CNN Hyperparameter Fine-tuning')
    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'],
                        help='Choose a dataset')
    parser.add_argument('--model', default='resnet18', choices=['resnet18', 'wideresnet'],
                        help='Choose a model')
    parser.add_argument('--num_finetune_epochs', type=int, default=200,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='sgdm', help='Choose an optimizer')
    parser.add_argument('--batch_size', type=int, default=128, help='Mini-batch size')
    parser.add_argument('--data_augmentation', action='store_true', default=True,
                        help='Whether to use data augmentation')
    parser.add_argument('--wdecay', type=float, default=5e-4, help='Amount of weight decay')
    parser.add_argument('--load_checkpoint', type=str,
                        help='Path to pre-trained checkpoint to load and finetune')
    parser.add_argument('--save_dir', type=str, default='finetuned_checkpoints',
                        help='Save directory for the fine-tuned checkpoint')
    args = parser.parse_args()
    args.load_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug0.pt'

    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size, val_split=True, augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size, val_split=True, augmentation=args.data_augmentation)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    test_id = '{}_{}_{}_lr{}_wd{}_aug{}'.format(args.dataset, args.model, args.optimizer,
                                                args.lr, args.wdecay,
                                                int(args.data_augmentation))
    filename = os.path.join(args.save_dir, test_id + '.csv')
    csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc',
                    'test_loss', 'test_acc'],
        filename=filename)

    checkpoint = torch.load(args.load_checkpoint)
    init_epoch = checkpoint['epoch']
    cnn.load_state_dict(checkpoint['model_state_dict'])
    model = cnn.cuda()
    model.train()

    args.hyper_train = 'augment'  # 'all_weight' # 'weight'

    def init_hyper_train(model):
        """Initialize the hyperparameters that will be tuned."""
        init_hyper = None
        if args.hyper_train == 'weight':
            # A single weight-decay value (square-root parameterized) shared by all parameters.
            init_hyper = np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor([init_hyper]).cuda(),
                                          requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        elif args.hyper_train == 'all_weight':
            # One weight-decay hyperparameter per model parameter.
            num_p = sum(p.numel() for p in model.parameters())
            weights = np.ones(num_p) * np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor(weights).cuda(),
                                          requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        model = model.cuda()
        return init_hyper

    if args.hyper_train == 'augment':
        # Don't build this inside init_hyper_train, else the scope is wrong.
        augment_net = UNet(in_channels=3, n_classes=3, depth=5, wf=6, padding=True,
                           batch_norm=False, up_mode='upconv')
        # TODO(PV): Initialize UNet properly
        augment_net = augment_net.cuda()

    def get_hyper_train():
        """Return the current hyperparameters as an iterable of tensors."""
        if args.hyper_train == 'weight' or args.hyper_train == 'all_weight':
            return [model.weight_decay]
        if args.hyper_train == 'augment':
            return augment_net.parameters()

    def get_hyper_train_flat():
        return torch.cat([p.view(-1) for p in get_hyper_train()])  # TODO: Check this size
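    # For example, with hyperparameter tensors of shapes (3,) and (2, 2),
    # get_hyper_train_flat() packs them into a single 1-D vector:
    #   torch.cat([p.view(-1) for p in [torch.zeros(3), torch.zeros(2, 2)]]).shape
    #   -> torch.Size([7])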
    init_hyper_train(model)

    if args.hyper_train == 'all_weight':
        wdecay = 0.0  # per-parameter decay is applied inside the loss instead
    else:
        wdecay = args.wdecay
    optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.2 * 0.2, momentum=0.9,
                          nesterov=True, weight_decay=wdecay)
    # print(checkpoint['optimizer_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = MultiStepLR(optimizer, milestones=[60, 120], gamma=0.2)  # [60, 120, 160]
    hyper_optimizer = torch.optim.Adam(get_hyper_train(), lr=1e-3)  # try 0.1 as lr

    # Set random regularization hyperparameters
    # data_augmentation_hparams = {}  # Random values for hue, saturation, brightness, contrast, rotation, etc.
    # Rebuild the data loaders (this duplicates the block above).
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size, val_split=True, augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size, val_split=True, augmentation=args.data_augmentation)

    def test(loader):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct = 0.
        total = 0.
        losses = []
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()
            with torch.no_grad():
                pred = model(images)
                xentropy_loss = F.cross_entropy(pred, labels)
                losses.append(xentropy_loss.item())
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()
        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    def prepare_data(x, y):
        """Move a batch onto the GPU."""
        x, y = x.cuda(), y.cuda()
        return x, y

    def train_loss_func(x, y):
        """Training loss: cross-entropy plus the hyperparameterized regularizer."""
        x, y = prepare_data(x, y)
        reg_loss = 0.0
        if args.hyper_train == 'weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            for p in model.parameters():
                reg_loss = reg_loss + .5 * (model.weight_decay ** 2) * torch.sum(p ** 2)
        elif args.hyper_train == 'all_weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            count = 0
            for p in model.parameters():
                reg_loss = reg_loss + .5 * torch.sum(
                    (model.weight_decay[count: count + p.numel()] ** 2) * torch.flatten(p ** 2))
                count += p.numel()
        elif args.hyper_train == 'augment':
            # Here the hyperparameters are the weights of the augmentation network.
            augmented_x = augment_net(x)
            pred = model(augmented_x)
            xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss + reg_loss, pred

    def val_loss_func(x, y):
        """Validation loss: plain cross-entropy, no regularizer."""
        x, y = prepare_data(x, y)
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss

    for epoch in range(init_epoch, init_epoch + args.num_finetune_epochs):
        xentropy_loss_avg = 0.
        total_val_loss = 0.
        correct = 0.
        total = 0.
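
        # hyper_step is defined elsewhere in the repo; from its call below it is
        # expected to take one optimization step on the hyperparameters using the
        # validation loss and return (val_loss, weight_norm, grad_norm) for logging.
        # Each inner iteration therefore alternates one hyperparameter update with
        # one ordinary SGD step on the training loss.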
        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            # Take a hyperparameter step before the regular training step.
            optimizer.zero_grad(), hyper_optimizer.zero_grad()
            val_loss, weight_norm, grad_norm = hyper_step(
                1, 1, get_hyper_train, get_hyper_train_flat, model, val_loss_func,
                val_loader, train_loss_func, train_loader, hyper_optimizer)
            # del val_loss
            # print(f"hyper: {get_hyper_train()}")

            images, labels = images.cuda(), labels.cuda()
            xentropy_loss, pred = train_loss_func(images, labels)
            optimizer.zero_grad(), hyper_optimizer.zero_grad()
            xentropy_loss.backward()
            optimizer.step()

            xentropy_loss_avg += xentropy_loss.item()

            # Calculate the running average of accuracy.
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            # NOTE: total_val_loss is never accumulated, so the running 'val'
            # field stays at 0; the actual validation loss is computed per epoch below.
            progress_bar.set_postfix(
                train='%.5f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / (i + 1)),
                acc='%.4f' % accuracy,
                weight='%.2f' % weight_norm,
                update='%.3f' % grad_norm)

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)

        row = {'epoch': str(epoch),
               'train_loss': str(xentropy_loss_avg / (i + 1)),
               'train_acc': str(accuracy),
               'val_loss': str(val_loss),
               'val_acc': str(val_acc),
               'test_loss': str(test_loss),
               'test_acc': str(test_acc)}
        csv_logger.writerow(row)
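
# A minimal entry point, assuming experiment() is the intended way to run this
# script (the original snippet does not show one):
if __name__ == '__main__':
    experiment()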