def run(args: DictConfig) -> None:
    """Train a SimCLR model with the NT-Xent contrastive loss.

    Builds the training dataset and loader for ``args.dataset``, obtains the
    model, optimizer and scheduler from ``load_model``, then trains for
    ``args.epochs`` epochs. The loss returned by ``train_epoch`` is logged each
    epoch. A checkpoint is saved every 10 epochs and after the final epoch.

    Args:
        args: Run configuration. Fields read directly here: ``device``,
            ``dataset``, ``data_dir``, ``batch_size``, ``workers``,
            ``temperature`` and ``epochs``. ``load_model``, ``train_epoch``
            and ``save_model`` receive ``args`` too and may read more.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``"cifar10"``.
    """
    # Fall back to CPU when CUDA is unavailable, whatever device was configured.
    args.device = args.device if torch.cuda.is_available() else "cpu"

    if args.dataset == "cifar10":
        train_dataset = datasets.CIFAR10(
            root=args.data_dir, download=True, transform=TransformsSimCLR()
        )
    else:
        # Without this branch an unsupported dataset fell through and failed
        # later with a NameError on ``train_dataset``.
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    # drop_last=True keeps every batch at exactly args.batch_size. The mask and
    # NT_Xent criterion below are both constructed from args.batch_size, so
    # presumably they require full batches (NOTE(review): confirm).
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )

    model, optimizer, scheduler = load_model(args)

    mask = mask_correlated_samples(args.batch_size)
    criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

    for epoch in range(args.epochs):
        train_loss = train_epoch(
            args, train_loader, model, criterion, optimizer, scheduler=scheduler
        )
        logger.info('Epoch {}, train_loss: {:.4f}'.format(epoch, train_loss))

        completed = epoch + 1
        # Save every 10 epochs, and always after the last one, so the final
        # weights are kept even when args.epochs is not a multiple of 10.
        if completed % 10 == 0 or completed == args.epochs:
            save_model(args, model, completed)
def run(args: DictConfig) -> None:
    """Linearly evaluate a pretrained SimCLR encoder with logistic regression.

    Loads the train/test split of ``args.dataset`` and obtains the encoder from
    ``load_model``, putting it in eval mode on ``args.device``. A
    ``LogisticRegression`` head on ``simclr_model.n_features`` features is
    trained for ``args.epochs`` epochs with Adam (lr=1e-3) and cross-entropy.
    Per-epoch training loss/accuracy and the final test loss/accuracy are
    printed.

    Args:
        args: Run configuration. Fields read directly here: ``device``,
            ``dataset``, ``data_dir``, ``logistic_batch_size``, ``workers``
            and ``epochs``. ``load_model`` and ``run_epoch`` receive ``args``
            too and may read more.

    Raises:
        NotImplementedError: If ``args.dataset`` is neither ``"STL10"`` nor
            ``"cifar10"``.
    """
    # Fall back to CPU when CUDA is unavailable, whatever device was configured.
    args.device = args.device if torch.cuda.is_available() else "cpu"

    # Plain tensor conversion only. A Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # step was deliberately left disabled in the original.
    transform = transforms.Compose([transforms.ToTensor()])

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            root=args.data_dir, split="train", download=True, transform=transform
        )
        test_dataset = torchvision.datasets.STL10(
            root=args.data_dir, split="test", download=True, transform=transform
        )
    elif args.dataset == "cifar10":
        train_dataset = torchvision.datasets.CIFAR10(
            root=args.data_dir, train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root=args.data_dir, train=False, download=True, transform=transform
        )
    else:
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )

    # drop_last=False so that every test sample is evaluated. Dropping the
    # incomplete final batch would report test metrics on a subset of the
    # test split.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    simclr_model, _, _ = load_model(args)
    simclr_model = simclr_model.to(args.device)
    simclr_model.eval()

    n_classes = 10  # both STL-10 and CIFAR-10 have 10 classes
    model = LogisticRegression(simclr_model.n_features, n_classes).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        loss, acc = run_epoch(args, train_loader, simclr_model, model, criterion, optimizer)
        print('Epoch {}, loss: {:.4f}, acc: {:.4f}'.format(epoch, loss, acc))

    # Final testing: run_epoch is called without an optimizer here, presumably
    # meaning evaluation without parameter updates (NOTE(review): confirm in
    # run_epoch).
    loss, acc = run_epoch(args, test_loader, simclr_model, model, criterion)
    print('Test loss: {:.4f}, acc: {:.4f}'.format(loss, acc))