Exemplo n.º 1
0
    ]) 
    if args.imshow == True:
        train_dataset = selfData(args.train_img, args.train_lab, transforms)
        train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True, num_workers = 0, drop_last= False)
        imgs, labels = train_loader.__iter__().__next__()
        imshow(train_loader)

    if args.model == 'mAlexNet':
        net = mAlexNet().to(device)
    elif args.model == 'AlexNet':
        net = AlexNet().to(device)

    criterion = nn.CrossEntropyLoss()
    if args.path == '':
        train(args.epochs, args.train_img, args.train_lab, transforms, net, criterion)
        PATH = './model.pth'
        torch.save(net.state_dict(), PATH)
        if args.model == 'mAlexNet':
            net = mAlexNet().to(device)
        elif args.model == 'AlexNet':
            net = AlexNet().to(device)
        net.load_state_dict(torch.load(PATH))
    else:
        PATH = args.path
        if args.model == 'mAlexNet':
            net = mAlexNet().to(device)
        elif args.model == 'AlexNet':
            net = AlexNet().to(device)
        net.load_state_dict(torch.load(PATH))
    accuracy = test(args.test_img, args.test_lab, transforms, net)
    print("\nThe accuracy of training on '{}' and testing on '{}' is {:.3f}.".format(args.train_lab.split('.')[0], args.test_lab.split('.')[0], accuracy))
Exemplo n.º 2
0

def strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with ``module`` path components removed.

    Checkpoint keys are dotted paths such as ``module.conv1.weight``. The
    ``module`` components are presumably introduced by a
    ``DataParallel``-style wrapper around the network (or one of its
    sub-networks) at save time -- confirm against the training script.

    Only components that are exactly ``module`` are dropped, so names that
    merely end in "module" (``submodule.bias``, ``my_module.weight``) are
    left intact. The final component (the parameter/buffer name itself) is
    never dropped.

    Args:
        state_dict: Mapping from dotted parameter names to values.

    Returns:
        A new ``dict`` with cleaned keys, in the original iteration order.
    """
    cleaned = {}
    for key, value in state_dict.items():
        # str.split always yields at least one element, so `leaf` is bound
        # even for keys without any dot.
        *parents, leaf = key.split('.')
        kept = [part for part in parents if part != 'module']
        cleaned['.'.join(kept + [leaf])] = value
    return cleaned


if __name__ == "__main__":
    # Instantiate the architecture selected on the command line.
    if args.model == 'mAlexNet':
        net = mAlexNet()
    elif args.model == 'AlexNet':
        net = AlexNet()
    elif args.model == "carnet":
        net = carNet()
    elif args.model == 'stn_shuf':
        net = stn_shufflenet()
    elif args.model == 'stn_trans_shuf':
        net = stn_trans_shufflenet()
    elif args.model == 'shuf':
        net = torchvision.models.shufflenet_v2_x1_0(pretrained=False,
                                                    num_classes=2)
    elif args.model == 'trans_shuf':
        net = trans_shufflenet()
    else:
        # Fail here with a clear message instead of a NameError on `net`
        # further down.
        raise ValueError("unknown model: {!r}".format(args.model))

    # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch
    # releases; kept as-is for the PyTorch version this script targets.
    torch.set_default_tensor_type('torch.FloatTensor')
    print("test net:{}..".format(args.model))

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code -- only pass checkpoints from a trusted source via args.path.
    # Weights are mapped to CPU first so loading works without a GPU.
    checkpoint = torch.load(args.path, map_location="cpu")
    net.load_state_dict(strip_data_parallel_prefix(checkpoint))

    if torch.cuda.is_available():
        net.cuda(int(args.cuda_device))

    acc = test(args.test_img, args.test_lab, net)
    print(args.test_lab, acc)