Esempio n. 1
0
    args = args_parser()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),  # normalize to [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]) 
    if args.imshow == True:
        train_dataset = selfData(args.train_img, args.train_lab, transforms)
        train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True, num_workers = 0, drop_last= False)
        imgs, labels = train_loader.__iter__().__next__()
        imshow(train_loader)

    if args.model == 'mAlexNet':
        net = mAlexNet().to(device)
    elif args.model == 'AlexNet':
        net = AlexNet().to(device)

    criterion = nn.CrossEntropyLoss()
    if args.path == '':
        train(args.epochs, args.train_img, args.train_lab, transforms, net, criterion)
        PATH = './model.pth'
        torch.save(net.state_dict(), PATH)
        if args.model == 'mAlexNet':
            net = mAlexNet().to(device)
        elif args.model == 'AlexNet':
            net = AlexNet().to(device)
        net.load_state_dict(torch.load(PATH))
    else:
        PATH = args.path
Esempio n. 2
0
                        if result:
                            if label=='1':
                                full_img=cv2.rectangle(full_img,(x,y),(x+w,y+w),color=(0,255,0))#green
                            else:
                                full_img = cv2.rectangle(full_img, (x, y), (x + w, y + w), color=(0,0 , 255))#red
                        else:
                            full_img = cv2.rectangle(full_img, (x, y), (x + w, y + w), color=(255, 0, 0))#blue
                targ_path=osp.join(save_path,cameras+"_"+repr(count)+'.jpg')
                print("save image:",cameras+"_"+repr(count)+'.jpg')
                cv2.imwrite(targ_path,full_img)
                # cv2.imshow("test",full_img)
                # cv2.waitKey(0)


if __name__ == "__main__":
    # Script entry: build the test dataset, load a trained network from
    # ``args.weight_path`` and run ``solve`` once per camera (camera1..camera9),
    # writing results under ``save_path``.
    # NOTE(review): ``args`` is expected to be defined earlier in this module
    # (presumably parsed command-line options) — confirm.
    test_img = selfData(args.test_img, args.test_lab, slot_id=True)
    # Plain string concatenation, not os.path.join: args.save_path is
    # presumably expected to end with a path separator — confirm.
    save_path = args.save_path + args.net

    if args.net == "carnet":
        net = carNet()
    elif args.net == "mAlexNet":
        net = mAlexNet()
    else:
        # Fail fast with a clear message instead of a NameError on ``net`` below.
        raise ValueError(
            "unsupported args.net %r; expected 'carnet' or 'mAlexNet'" % (args.net,)
        )

    # SECURITY NOTE: torch.load unpickles the checkpoint file, which can execute
    # arbitrary code — only load weight files from trusted sources.
    state_dict = torch.load(args.weight_path, map_location="cpu")
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    # Strip only that leading prefix; str.replace would also mangle keys that
    # merely contain "module." elsewhere (e.g. "submodule.conv.weight").
    prefix = "module."
    net.load_state_dict({
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    })
    if torch.cuda.is_available():
        net.cuda(args.cuda_device)

    for index in range(1, 10):
        # Deliberately not named ``str``: at module level that would rebind the
        # builtin ``str`` for every function in this file, including ``solve``.
        camera = "camera" + repr(index)
        print("camera:", index)
        solve(camera, net, test_img, save_path)