download=True)
# Evaluation split of CIFAR-10, built with the evaluation-time transform.
test_dataset = torchvision.datasets.CIFAR10(root='./data', transform=transform_test, train=False, download=True)
# Training batches are shuffled; evaluation batches (fixed size 128) are
# served in dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, num_workers=4, shuffle=False)
# Pick the ODE-based backbone. The literal 10 passed to SqNxt_23_1x is
# presumably the CIFAR-10 class count -- TODO confirm against its signature.
# NOTE(review): any other args.network value leaves `net` unbound, so the
# net.apply(...) call below would fail with a NameError.
if args.network == 'sqnxt':
    net = SqNxt_23_1x(10, ODEBlock)
elif args.network == 'resnet':
    net = ResNet18(ODEBlock)
# Apply the custom initializer to every submodule.
net.apply(conv_init)
print(net)
if is_use_cuda:
    net.cuda()  # to(device)
    # Wrap the model so batches are split across the visible GPUs.
    net = nn.DataParallel(net)
criterion = nn.CrossEntropyLoss()
# SGD with momentum 0.9 and L2 weight decay 5e-4; learning rate from args.lr.
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Build the Galaxy Zoo data pipeline. The helper hands back the training
# dataset alongside the train/test loaders.
train_loader, test_loader, train_dataset = get_galaxyZoo_loaders(
    batch_size=args.batch_size, test_batch_size=args.test_batch_size)

# Width of the classifier head: 20 outputs for MTVSO, 10 otherwise.
if args.dataset == 'MTVSO':
    num_classes = 20
else:
    num_classes = 10

# Instantiate the requested ODE backbone. Fail fast with a descriptive error
# on an unrecognised args.network value; otherwise `net` would be unbound and
# the net.apply(...) call below would raise an unhelpful NameError.
if args.network == 'sqnxt':
    net = SqNxt_23_1x(num_classes, ODEBlock)
elif args.network == 'resnet':
    net = ResNet18(ODEBlock, num_classes=num_classes)
else:
    raise ValueError(
        "Unknown network {!r}; expected 'sqnxt' or 'resnet'".format(args.network))

# Apply the custom initializer to every submodule.
net.apply(conv_init)
print(net)

if is_use_cuda:
    net.to(device)
    # Replicate the model over every visible GPU.
    net = nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))

criterion = nn.CrossEntropyLoss()
# SGD with momentum 0.9 and L2 weight decay 5e-4; learning rate from args.lr.
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)