def get_model(device=None,
              checkpoint_path='./models/best_linear_svm_alexnet_car.pth'):
    """Load the 2-class AlexNet checkpoint and prepare it for inference.

    Args:
        device: Optional target device (e.g. ``torch.device("cuda")`` or
            ``"cpu"``). If truthy, the model is moved there before returning.
        checkpoint_path: Path to the saved ``state_dict``. Defaults to the
            previously hard-coded ``./models/best_linear_svm_alexnet_car.pth``.

    Returns:
        The ``AlexNet(num_classes=2)`` instance in eval mode with every
        parameter's ``requires_grad`` set to ``False``.
    """
    # Load the CNN model.
    model = AlexNet(num_classes=2)
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # host; load_state_dict copies the weights into the model's own tensors,
    # and the model is moved to ``device`` below if one was requested.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    # Disable gradient tracking: the model is used for inference only.
    for param in model.parameters():
        param.requires_grad = False
    if device:
        model = model.to(device)
    return model
def train():
    """Fine-tune AlexNet on the dogs-vs-cats train split and save the result.

    Builds DataLoaders over ``./200508_cat_classification/dogs-vs-cats/{train,test}``
    and loads initial weights from ``./alexnet.pth`` via ``load_model``. It then
    trains for 3 epochs with SGD (lr=0.004) and cross-entropy loss. Every 2000
    mini-batches it prints the mean loss and the cumulative accuracy count.
    Finally it calls ``save_model(net, './')``.
    Works on CPU-only hosts as well as single- or multi-GPU hosts.
    """
    import torch.optim as optim

    torch.multiprocessing.freeze_support()

    batch_size = 4
    log_every = 2000

    # Join the dataset root with each split's directory name.
    data_root = './200508_cat_classification/dogs-vs-cats'
    traindir = os.path.join(data_root, 'train')
    testdir = os.path.join(data_root, 'test')

    # ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = datautil.DataLoader(
        TrainImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    # NOTE(review): test_loader is built but never iterated in this function.
    test_loader = datautil.DataLoader(
        TestImageFolder(
            testdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=1, shuffle=False, num_workers=1, pin_memory=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = AlexNet()
    load_model(net, './alexnet.pth')
    # Move to the selected device after loading, so the final placement
    # does not depend on where load_model left the parameters.
    net = net.to(device)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.004)

    net.train()
    for epoch in range(3):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(train_loader, 0):
            # Use the selected device instead of hard-coded .cuda() so the
            # loop also runs on CPU-only machines.
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            prediction = outputs.argmax(dim=1)
            correct += (prediction == labels).sum().item()
            # Count actual samples; the last batch may be smaller than batch_size.
            total += labels.size(0)

            if i % log_every == log_every - 1:
                print(
                    f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.6f} acc : {correct} / {total}'
                )
                running_loss = 0.0

    print('Finished Training')
    save_model(net, './')