import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import datasets, transforms
from tqdm import tqdm

# Repo-local modules (Cutout repository); adjust import paths to your layout.
from util.misc import CSVLogger
from util.cutout import Cutout
from model.resnet import ResNet18
from model.wide_resnet import WideResNet


def run_cutout(dataset="cifar10", model="resnet18", epochs=200, batch_size=128,
               learning_rate=0.1, data_augmentation=False, cutout=False,
               n_holes=1, length=8, no_cuda=False, seed=0):
    cuda = not no_cuda and torch.cuda.is_available()
    cudnn.benchmark = True  # Should make training go faster for large models
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    test_id = dataset + '_' + model

    # Image Preprocessing
    if dataset == 'svhn':
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
            std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
    else:
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    train_transform = transforms.Compose([])
    if data_augmentation:
        train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(transforms.RandomHorizontalFlip())
    train_transform.transforms.append(transforms.ToTensor())
    train_transform.transforms.append(normalize)
    if cutout:
        train_transform.transforms.append(
            Cutout(n_holes=n_holes, length=length))

    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    if dataset == 'cifar10':
        num_classes = 10
        train_dataset = datasets.CIFAR10(root='data/', train=True,
                                         transform=train_transform,
                                         download=True)
        test_dataset = datasets.CIFAR10(root='data/', train=False,
                                        transform=test_transform,
                                        download=True)
    elif dataset == 'cifar100':
        num_classes = 100
        train_dataset = datasets.CIFAR100(root='data/', train=True,
                                          transform=train_transform,
                                          download=True)
        test_dataset = datasets.CIFAR100(root='data/', train=False,
                                         transform=test_transform,
                                         download=True)
    elif dataset == 'svhn':
        num_classes = 10
        train_dataset = datasets.SVHN(root='data/', split='train',
                                      transform=train_transform,
                                      download=True)
        extra_dataset = datasets.SVHN(root='data/', split='extra',
                                      transform=train_transform,
                                      download=True)
        # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
        data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
        labels = np.concatenate([train_dataset.labels, extra_dataset.labels],
                                axis=0)
        train_dataset.data = data
        train_dataset.labels = labels
        test_dataset = datasets.SVHN(root='data/', split='test',
                                     transform=test_transform,
                                     download=True)

    # Data Loader (Input Pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    if model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif model == 'wideresnet':
        if dataset == 'svhn':
            cnn = WideResNet(depth=16, num_classes=num_classes,
                             widen_factor=8, dropRate=0.4)
        else:
            cnn = WideResNet(depth=28, num_classes=num_classes,
                             widen_factor=10, dropRate=0.3)

    # The script assumes a CUDA device is available for training.
    cnn = cnn.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=learning_rate,
                                    momentum=0.9, nesterov=True,
                                    weight_decay=5e-4)

    if dataset == 'svhn':
        scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
    else:
        scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160],
                                gamma=0.2)

    filename = 'logs/' + test_id + '.csv'

    args = argparse.Namespace(**{
        "dataset": dataset, "model": model, "epochs": epochs,
        "batch_size": batch_size, "learning_rate": learning_rate,
        "data_augmentation": data_augmentation, "cutout": cutout,
        "n_holes": n_holes, "length": length, "no_cuda": no_cuda,
        "seed": seed
    })
    csv_logger = CSVLogger(args=args,
                           fieldnames=['epoch', 'train_acc', 'test_acc'],
                           filename=filename)

    def test(loader):
        cnn.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct = 0.
        total = 0.
        with torch.no_grad():
            for images, labels in loader:
                if dataset == 'svhn':
                    # SVHN labels are from 1 to 10, not 0 to 9, so subtract 1.
                    # (Newer torchvision versions already remap class 10 to 0,
                    # in which case this shift should be dropped.)
                    labels = labels.type_as(torch.LongTensor()).view(-1) - 1
                images = images.cuda()
                labels = labels.cuda()

                pred = cnn(images)
                pred = torch.max(pred, 1)[1]
                total += labels.size(0)
                correct += (pred == labels).sum().item()

        val_acc = correct / total
        cnn.train()
        return val_acc

    for epoch in range(epochs):
        xentropy_loss_avg = 0.
        correct = 0.
        total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Epoch ' + str(epoch))

            if dataset == 'svhn':
                # SVHN labels are from 1 to 10, not 0 to 9, so subtract 1
                labels = labels.type_as(torch.LongTensor()).view(-1) - 1

            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            cnn.zero_grad()
            pred = cnn(images)

            xentropy_loss = criterion(pred, labels)
            xentropy_loss.backward()
            cnn_optimizer.step()

            xentropy_loss_avg += xentropy_loss.item()

            # Calculate running average of accuracy
            _, pred = torch.max(pred.data, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
                acc='%.3f' % accuracy)

        test_acc = test(test_loader)
        tqdm.write('test_acc: %.3f' % test_acc)

        scheduler.step()

        row = {
            'epoch': str(epoch),
            'train_acc': str(accuracy),
            'test_acc': str(test_acc)
        }
        csv_logger.writerow(row)

    # torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
    csv_logger.close()

    # Validation error for Hyperband.
    results = {
        'epoch': epoch,
        'train_error': 1 - accuracy,
        'test_error': 1 - test_acc
    }
    return results
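
# A minimal usage sketch (an assumption, not part of the original script):
# invoke run_cutout directly. n_holes=1, length=16 matches the Cutout paper's
# CIFAR-10 setting; any keyword argument of run_cutout can be swapped in.
if __name__ == '__main__':
    results = run_cutout(dataset='cifar10', model='resnet18', epochs=200,
                         data_augmentation=True, cutout=True,
                         n_holes=1, length=16)
    print('final errors:', results)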