def main():
    """Train a baseline CNN on rotated-MNIST training domains (0°/30°/90°)
    with either random-rotation or random-flip data augmentation, checkpoint
    the best model by validation accuracy, and evaluate that checkpoint on
    the held-out 60° rotation test domain, writing the test accuracy to a
    text file.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--da', type=str, default='rotate',
                        choices=['rotate', 'flip'],
                        help='type of data augmentation')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # Set seed for reproducibility
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # BUG FIX: honor --no-cuda / missing GPU. The original unconditionally
    # requested torch.device("cuda"), which crashes on CPU-only machines.
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    # Load supervised training domains; each rotation angle is a separate
    # domain with its own domain label d_label.
    if args.da == 'rotate':
        mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0.0],
                                   d_label=0, transform=True)
        mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0],
                                    d_label=1, transform=True)
        mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0],
                                    d_label=2, transform=True)
        model_name = 'baseline_test_0_random_rotate_seed_' + str(args.seed)
    elif args.da == 'flip':
        mnist_0 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[0.0],
                                       d_label=0)
        mnist_30 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[30.0],
                                        d_label=1)
        mnist_90 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[90.0],
                                        d_label=2)
        model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed)

    mnist = data_utils.ConcatDataset([mnist_0, mnist_30, mnist_90])

    # 90/10 train/validation split over the concatenated domains.
    train_size = int(0.9 * len(mnist))
    val_size = len(mnist) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        mnist, [train_size, val_size])

    train_loader = data_utils.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True, **kwargs)
    val_loader = data_utils.DataLoader(val_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    best_val_acc = 0
    for epoch in range(1, args.epochs + 1):
        print('\n Epoch: ' + str(epoch))
        train(args, model, device, train_loader, optimizer, epoch)
        val_loss, val_acc = test(args, model, device, val_loader)
        print(epoch, val_loss, val_acc)

        # Checkpoint whenever validation accuracy ties or beats the best so
        # far (>= keeps the most recent model on ties).
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            torch.save(model, model_name + '.model')
            torch.save(args, model_name + '.config')

    # Evaluate the best checkpoint on the unseen 60° rotation test domain.
    mnist_60 = MnistRotated('../dataset/', train=False, thetas=[60.0],
                            d_label=0)
    test_loader = data_utils.DataLoader(mnist_60,
                                        batch_size=args.batch_size,
                                        shuffle=False, **kwargs)

    model = torch.load(model_name + '.model').to(device)
    _, test_acc = test(args, model, device, test_loader)

    with open(model_name + '.txt', "w") as text_file:
        text_file.write("Test Acc: " + str(test_acc))
def main():
    """Train a CNN on rotated-MNIST domains (0°/30°/60°) under one selected
    torchvision data augmentation, log per-epoch train/val accuracy to wandb,
    and report the mean accuracy over the final 10 epochs to a text file and
    the wandb run summary.

    NOTE(review): this file defines ``main`` twice; this later definition
    shadows the earlier one when the module is executed.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=50,
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--da', type=str, default='scale',
                        choices=['brightness', 'contrast', 'saturation', 'hue',
                                 'rotation', 'translate', 'scale', 'shear',
                                 'vflip', 'hflip', 'none'])
    parser.add_argument('-dd', '--data_dir', type=str, default='./data',
                        help='Directory to download data to and load data from')
    parser.add_argument('-wd', '--wandb_dir', type=str, default='./',
                        help='(OVERRIDDEN BY ENV_VAR for sweep) Directory to download data to and load data from')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.da)

    # Set seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # BUG FIX: honor --no-cuda / missing GPU. The original unconditionally
    # requested torch.device("cuda"), which crashes on CPU-only machines.
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    # One torchvision transform per augmentation name; 'none' disables
    # augmentation entirely (transform=None on the dataset).
    transform_dict = {
        'brightness': torchvision.transforms.ColorJitter(
            brightness=1.0, contrast=0, saturation=0, hue=0),
        'contrast': torchvision.transforms.ColorJitter(
            brightness=0, contrast=10.0, saturation=0, hue=0),
        'saturation': torchvision.transforms.ColorJitter(
            brightness=0, contrast=0, saturation=10.0, hue=0),
        'hue': torchvision.transforms.ColorJitter(
            brightness=0, contrast=0, saturation=0, hue=0.5),
        'rotation': torchvision.transforms.RandomAffine(
            [0, 359], translate=None, scale=None, shear=None,
            resample=PIL.Image.BILINEAR, fillcolor=0),
        'translate': torchvision.transforms.RandomAffine(
            0, translate=[0.2, 0.2], scale=None, shear=None,
            resample=PIL.Image.BILINEAR, fillcolor=0),
        'scale': torchvision.transforms.RandomAffine(
            0, translate=None, scale=[0.8, 1.2], shear=None,
            resample=PIL.Image.BILINEAR, fillcolor=0),
        'shear': torchvision.transforms.RandomAffine(
            0, translate=None, scale=None, shear=[-10., 10., -10., 10.],
            resample=PIL.Image.BILINEAR, fillcolor=0),
        'vflip': torchvision.transforms.RandomVerticalFlip(p=0.5),
        'hflip': torchvision.transforms.RandomHorizontalFlip(p=0.5),
        'none': None,
    }

    # Each train/val pair of the same domain is built from the same RNG state
    # so both see an identically-sampled digit subset; the state is re-drawn
    # per domain.
    rng_state = np.random.get_state()
    mnist_0_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[0],
                                       d_label=0,
                                       transform=transform_dict[args.da],
                                       rng_state=rng_state)
    mnist_0_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[0],
                                     d_label=0, transform=None,
                                     rng_state=rng_state)

    rng_state = np.random.get_state()
    mnist_30_train = MnistRotatedDistDa('../dataset/', train=True,
                                        thetas=[30.0], d_label=1,
                                        transform=transform_dict[args.da],
                                        rng_state=rng_state)
    mnist_30_val = MnistRotatedDistDa('../dataset/', train=False,
                                      thetas=[30.0], d_label=1,
                                      transform=None, rng_state=rng_state)

    rng_state = np.random.get_state()
    mnist_60_train = MnistRotatedDistDa('../dataset/', train=True,
                                        thetas=[60.0], d_label=2,
                                        transform=transform_dict[args.da],
                                        rng_state=rng_state)
    mnist_60_val = MnistRotatedDistDa('../dataset/', train=False,
                                      thetas=[60.0], d_label=2,
                                      transform=None, rng_state=rng_state)

    mnist_train = data_utils.ConcatDataset(
        [mnist_0_train, mnist_30_train, mnist_60_train])
    train_loader = data_utils.DataLoader(mnist_train, batch_size=100,
                                         shuffle=True, **kwargs)

    mnist_val = data_utils.ConcatDataset(
        [mnist_0_val, mnist_30_val, mnist_60_val])
    val_loader = data_utils.DataLoader(mnist_val, batch_size=100,
                                       shuffle=True, **kwargs)

    wandb.init(project="NewRotated_MNIST", config=args, name=args.da)

    model_name = 'baseline_test_0_' + args.da + '_seed_' + str(args.seed)
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train_accs = []
    val_accs = []
    for epoch in range(1, args.epochs + 1):
        print('\n Epoch: ' + str(epoch))
        train(args, model, device, train_loader, optimizer, epoch)
        train_loss, train_acc = test(args, model, device, train_loader)
        val_loss, val_acc = test(args, model, device, val_loader)
        print(train_acc, val_acc)
        wandb.log({'train accuracy': train_acc, 'val accuracy': val_acc})
        train_accs.append(train_acc)
        val_accs.append(val_acc)

    # Summarize with the mean accuracy over the last 10 epochs.
    mean_train_accs = np.mean(np.array(train_accs)[-10:])
    print(mean_train_accs)
    mean_val_accs = np.mean(np.array(val_accs)[-10:])
    print(mean_val_accs)

    with open(model_name + '.txt', "w") as text_file:
        # BUG FIX: add a newline between the two records; the original wrote
        # them back-to-back onto a single unreadable line.
        text_file.write("Mean train acc: " + str(mean_train_accs) + "\n")
        text_file.write("Mean val acc: " + str(mean_val_accs))

    wandb.run.summary["mean_train_accs"] = mean_train_accs
    wandb.run.summary["mean_val_accs"] = mean_val_accs