Exemplo n.º 1
0
def load_model_vqa(map_location):
    """Load the CLEVR VQA network and its vocabularies into module globals.

    Reads the question list from ``data/CLEVR_train_questions.json``, builds
    the token and answer vocabularies with ``tokenize`` and ``ans_stat``,
    restores the weights stored under the ``'state_dict'`` key of
    ``CKPT/epoch19.pkl`` into a fresh ``Net`` and puts it in eval mode.

    Args:
        map_location: Forwarded unchanged to ``torch.load`` to control where
            the checkpoint tensors are placed.

    Side effects:
        Rebinds the globals ``net_vqa``, ``max_token``, ``token_to_ix`` and
        ``ix_to_ans``. ``ques_ix`` is declared global but not assigned here.
    """
    global ques_ix
    global net_vqa
    global max_token
    global token_to_ix
    global ix_to_ans
    model_path = 'CKPT/epoch19.pkl'
    train_path = 'data/CLEVR_train_questions.json'
    print("Load model...")
    state_dict = torch.load(model_path,
                            map_location=map_location)['state_dict']

    # Parse the question file once and close the handle deterministically;
    # the same list feeds both vocabulary builders.
    # NOTE(review): assumes tokenize() does not mutate its input list.
    with open(train_path, 'r') as question_file:
        questions = json.load(question_file)['questions']

    token_to_ix, pretrained_emb, max_token = tokenize(questions, False)
    ans_to_ix, ix_to_ans = ans_stat(questions)

    ans_size = len(ans_to_ix)
    token_size = len(token_to_ix)

    print("token_size:", token_size)
    print("ans_size:", ans_size)
    net_vqa = Net(pretrained_emb, token_size, ans_size)
    net_vqa.load_state_dict(state_dict)
    net_vqa.eval()
Exemplo n.º 2
0
def predict():
    """Run the trained classifier over the test set and write a submission.

    Builds ``Net(model_name)``, restores its weights from
    ``<config.model_path>/<model_name>.bin``, predicts the arg-max class for
    every row of the CSV at ``config.test_path`` and writes the result to
    ``submission.csv`` (no header, no index).

    Relies on the module-level names ``model_name``, ``device``, ``config``,
    ``test_transform``, ``MyDataset`` and ``Net``.
    """
    model = Net(model_name).to(device)
    model_save_path = os.path.join(config.model_path, '{}.bin'.format(model_name))
    # map_location lets a checkpoint saved on another device (e.g. GPU) be
    # restored on whatever `device` this process is actually using.
    model.load_state_dict(torch.load(model_save_path, map_location=device))
    model.eval()

    test_df = pd.read_csv(config.test_path)
    # Turn each FileID into the path of the corresponding test image.
    test_df['FileID'] = test_df['FileID'].apply(lambda x: '{}/{}.jpg'.format(config.image_test_path, x))
    print('test:{}'.format(test_df.shape[0]))
    test_dataset = MyDataset(test_df, test_transform, 'test')
    # shuffle=False keeps predictions in the same order as the CSV rows.
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    pred_list = []
    with torch.no_grad():
        for batch_x, _ in tqdm(test_loader):
            probs = model(batch_x.to(device))
            pred_list.extend(torch.argmax(probs, dim=1).tolist())

    # NOTE(review): FileID in the submission is the positional index of each
    # prediction, not the FileID value read from the CSV - confirm intended.
    submission = pd.DataFrame({"FileID": range(len(pred_list)), "SpeciesID": pred_list})
    submission.to_csv('submission.csv', index=False, header=False)
Exemplo n.º 3
0
def visualize(model: str, images: [str], occlusion_window: int,
              occlusion_stride: int, no_occlussion: bool, no_gradient: bool):
    """Classify each image and optionally run the visualisation helpers.

    Args:
        model: Path handed to ``torch.load`` to restore the network weights.
        images: Paths of the images to process, one ``utils.load_image`` call
            each.
        occlusion_window: Passed to ``occlustion`` as ``k``.
        occlusion_stride: Passed to ``occlustion`` as ``stride``.
        no_occlussion: When true, the ``occlustion`` call is skipped.
        no_gradient: When true, the ``gradient`` call is skipped.
    """
    device_name = 'cuda:0' if cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    click.secho('Using device={}'.format(device), fg='blue')

    net = Net()
    net.to(device)

    click.secho("Loading model from '{}'".format(model), fg='yellow')
    weights = torch.load(model, map_location=device)
    net.load_state_dict(weights)
    net.eval()

    run_occlusion = not no_occlussion
    run_gradient = not no_gradient

    for path in images:
        image = utils.load_image(path).to(device)
        scores = net(image)
        # Index of the highest score along dimension 1.
        predicted = torch.max(scores.data, 1)[1]
        label = classes[predicted]
        click.echo(
            "Image '{}' most likely represents a '{}'".format(path, label))

        if run_occlusion:
            occlustion(net, image, predicted,
                       k=occlusion_window, stride=occlusion_stride)
        if run_gradient:
            gradient(net, image, predicted)
Exemplo n.º 4
0
def train(load_model: str, save_model: str, train_dataset: str,
          test_dataset: str, no_train: bool, no_test: bool, epochs: int,
          batch_size: int, learning_rate: float):
    """Optionally restore, train, save and evaluate a ``Net``.

    Args:
        load_model: Checkpoint path to restore before training, or ``None``.
        save_model: Path to write the state dict to after training, or
            ``None``; ignored when training is skipped.
        train_dataset: Forwarded to ``train_net`` as ``data_path``.
        test_dataset: Forwarded to ``test_net`` as ``data_path``.
        no_train: Skip the training (and saving) phase when true.
        no_test: Skip the evaluation phase when true.
        epochs: Forwarded to ``train_net`` as ``num_epochs``.
        batch_size: Batch size for both ``train_net`` and ``test_net``.
        learning_rate: Forwarded to ``train_net``.
    """
    device_name = 'cuda:0' if cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    click.secho('Using device={}'.format(device), fg='blue')

    net = Net()
    net.to(device)

    if load_model is not None:
        click.secho("Loading model from '{}'".format(load_model), fg='yellow')
        restored = torch.load(load_model, map_location=device)
        net.load_state_dict(restored)

    if not no_train:
        click.echo('Training model using {}'.format(train_dataset))
        net.train()
        train_net(net, data_path=train_dataset, batch_size=batch_size,
                  num_epochs=epochs, learning_rate=learning_rate)

        # Weights are only persisted when a training run actually happened.
        if save_model is not None:
            click.secho("Saving model as '{}'".format(save_model),
                        fg='yellow')
            torch.save(net.state_dict(), save_model)

    if not no_test:
        click.echo('Testing model using {}'.format(test_dataset))
        net.eval()
        accuracy = test_net(net, data_path=test_dataset, batch_size=batch_size)
        # Green above the 97 threshold, red otherwise.
        if accuracy > 97.:
            color = 'green'
        else:
            color = 'red'
        click.secho('Accuracy={}'.format(accuracy), fg=color)
Exemplo n.º 5
0
def load_model_vqa(map_location):
    """Load the MUAN VQA network and its vocabularies into module globals.

    Reads the token vocabulary, the pretrained embedding matrix and the
    answer dictionaries from ``./data``, restores the weights stored under
    the ``'state_dict'`` key of ``./model/muan.pkl`` into a fresh ``Net``,
    moves it to CUDA and puts it in eval mode. Progress and the total load
    time are printed.

    Args:
        map_location: Forwarded unchanged to ``torch.load``.

    Side effects:
        Rebinds the globals ``net_vqa``, ``token_to_ix`` and ``ix_to_ans``.
    """
    global net_vqa
    global token_to_ix
    global ix_to_ans

    # Base path
    model_path = './model/muan.pkl'
    token_to_ix_path = './data/token_to_ix.json'
    pretrained_emb_path = './data/pretrained_emb.npy'
    ans_dict_path = './data/vqa_answer_dict.json'

    # --- Pre-load ---
    # Load token_to_ix (context manager so the handle is always closed)
    with open(token_to_ix_path, 'r') as token_file:
        token_to_ix = json.load(token_file)
    token_size = len(token_to_ix)
    print(' ========== Question token vocab size:', token_size)

    # Load pretrained_emb
    pretrained_emb = np.load(pretrained_emb_path)
    print('pretrained_emb shape:', pretrained_emb.shape)

    # Answers statistic: the JSON file holds a two-element sequence that
    # unpacks into (ans_to_ix, ix_to_ans).
    with open(ans_dict_path, 'r') as ans_file:
        ans_to_ix, ix_to_ans = json.load(ans_file)
    ans_size = len(ans_to_ix)
    print(
        ' ========== Answer token vocab size (occur more than {} times):'.
        format(8), ans_size)

    # --- Load the pre-trained model ---
    # Load model ckpt
    time_start = time.time()
    print('\nLoading ckpt from: {}'.format(model_path))
    state_dict = torch.load(model_path,
                            map_location=map_location)['state_dict']
    print('state_dict num:', len(state_dict.keys()))
    print('Finish load state_dict!')

    # Load model
    net_vqa = Net(pretrained_emb, token_size, ans_size)
    net_vqa.load_state_dict(state_dict)
    # NOTE(review): the model is moved to CUDA unconditionally, whatever
    # map_location was passed - confirm this is intended on CPU-only hosts.
    net_vqa.cuda()
    net_vqa.eval()
    time_end = time.time()
    print('Finish load net model!')
    print('Model load time: {:.3f}s\n'.format(time_end - time_start))
Exemplo n.º 6
0
    time_end = time.time()
    print('Image feature process time: {:.3f}s'.format(time_end - time_start))

    # Load model ckpt
    print('\nLoading ckpt from: {}'.format(model_path))
    time_start = time.time()
    state_dict = torch.load(model_path,
                            map_location=map_location)['state_dict']
    print('state_dict num:', len(state_dict.keys()))
    print('Finish load state_dict!')

    # Load model
    net = Net(pretrained_emb, token_size, ans_size)
    net.cuda()
    net.eval()
    net.load_state_dict(state_dict)
    # print('net:', net)
    time_end = time.time()
    print('Finish load net model!')
    print('Model load time: {:.3f}s\n'.format(time_end - time_start))

    # Predict
    time_start = time.time()
    pred = net(imgfeat_batch, bboxfeat_batch, quesix_batch)
    pred_np = pred.cpu().data.numpy()
    pred_argmax = np.argmax(pred_np, axis=1)[0]
    pred_ans = ix_to_ans[str(pred_argmax)]
    print('pred_argmax:', pred_argmax)
    print('pred_ans:', pred_ans)
    if language in ['ZH', 'zh']:
Exemplo n.º 7
0
def main():
    """Run the webcam gesture assistant until ESC is pressed.

    Announces the available gestures, opens camera 0, loads the classifier
    from ``./model/model_sl_3968.pt`` and then loops over frames:

    * 'z' toggles detection; turning it on samples a hand histogram from the
      current frame and resets the running scores.
    * While detection is on and the global ``in_data`` is set, the crop is
      converted to grayscale, viewed as a 1x1x64x64 tensor, classified, and
      the per-class running score is updated.
    * When the smoothed label is "ONE".."FIVE", a streak counter for that
      label is incremented and printed; on the 10th count the matching
      action is announced and launched, and the counter restarts. Seeing a
      different one of those five labels restarts the streak; any other
      label leaves it untouched.

    The capture device and all OpenCV windows are released on exit, even if
    an exception interrupts the loop.
    """
    global hand_hist, in_data

    # Init running label
    sayToMe('I am your gesture assistant. Speed : 1 terahertz, memory : 1 zigabyte.')
    sayToMe('Perform one of the below gestures')
    print("I am your gesture assistant...")
    print("############################################################################")
    print("Show the numbers below to the camera to begin.....")
    print("############################################################################")
    print("1) Open Gmail")
    print("2) Start Video Player")
    print("3) Start Microsoft Word")
    print("4) Open Play Music")
    print("5) Open Devpost.com")
    print("############################################################################")

    # label -> (spoken announcement, callable that launches the action)
    gesture_actions = {
        "ONE": ('Opening gmail',
                lambda: webbrowser.open('https://www.gmail.com/')),
        "TWO": ('Starting Video Player',
                lambda: os.system("start D:/Devpost_Hackathons/pytorch/test_video1.mp4")),
        "THREE": ('Starting Microsoft word',
                  lambda: os.system("start winword")),
        "FOUR": ('Opening Play Music',
                 lambda: webbrowser.open('https://play.google.com/music/listen?u=0#/home')),
        "FIVE": ('Opening Devpost',
                 lambda: webbrowser.open('https://devpost.com')),
    }
    # Only one gesture can have a non-zero streak at a time, so a single
    # (label, count) pair replaces the five separate counters.
    streak_label = None
    streak_count = 0

    last_10_detection = np.zeros(10)

    # Init Video Capture
    is_hand_hist_created = False
    capture = cv2.VideoCapture(0)

    # Init Model (map_location keeps every tensor on the CPU)
    model = Net()
    model.load_state_dict(torch.load('./model/model_sl_3968.pt', map_location=lambda storage, location: storage))
    model.eval()

    detection_result = 'None'

    try:
        while capture.isOpened():
            pressed_key = cv2.waitKey(1)
            grabbed, frame = capture.read()
            if not grabbed:
                # No frame was delivered; processing a missing frame would
                # crash in the helpers below, so stop cleanly instead.
                break

            # Start / Stop Detection when 'z' pressed
            if pressed_key & 0xFF == ord('z'):
                if is_hand_hist_created:
                    is_hand_hist_created = False
                else:
                    is_hand_hist_created = True
                    hand_hist = hand_histogram(frame)

                    # Reinit running label
                    last_10_detection = np.zeros(10)

            if is_hand_hist_created:
                # presumably this call refreshes the global in_data - confirm
                frame = manage_image_opr(frame, hand_hist)
            else:
                frame = draw_rect(frame)

            # Perform Detection
            if in_data is not None and is_hand_hist_created:
                g_img = cv2.cvtColor(in_data, cv2.COLOR_BGR2GRAY)
                x = torch.FloatTensor(g_img).view(1, 1, 64, 64) / 255
                x = (x - 0.5) / 0.5

                with torch.no_grad():
                    y = model(x)
                    y_idx = F.softmax(y, dim=-1).argmax().numpy()

                # Update likelihood: net +1 for the predicted class, -1 for
                # every other class, clamped to [0, 8].
                last_10_detection[y_idx] += 2
                last_10_detection = last_10_detection - 1
                last_10_detection = np.clip(last_10_detection, 0, 8)

                detection_result = label_dict[int(np.argmax(last_10_detection))]

                if detection_result in gesture_actions:
                    if detection_result != streak_label:
                        streak_label = detection_result
                        streak_count = 0
                    streak_count += 1
                    print(streak_count)

                    if streak_count == 10:
                        announcement, launch = gesture_actions[detection_result]
                        sayToMe(announcement)
                        launch()
                        streak_count = 0
            else:
                detection_result = 'None'

            # Render to Screen
            if is_hand_hist_created and in_data is not None:
                frame[:75, :180, :] = 0
                frame[:64, -64:, :] = np.expand_dims(cv2.cvtColor(in_data, cv2.COLOR_BGR2GRAY), axis=-1)
                cv2.putText(frame, 'DETECTED', (5, 30), cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)
                cv2.putText(frame, '{}'.format(detection_result), (5, 65), cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)

            cv2.imshow("FunTorch", rescale_frame(frame))

            # Close if ESC pressed (masked the same way as the 'z' check)
            if pressed_key & 0xFF == 27:
                break
    finally:
        cv2.destroyAllWindows()
        capture.release()
Exemplo n.º 8
0
def main():
    """Run the webcam hand-sign detector until ESC is pressed.

    Opens camera 0, loads the classifier from ``./model/model_sl_3968.pt``
    and loops over frames:

    * 'z' toggles detection; turning it on samples a hand histogram from the
      current frame and resets the running scores.
    * While detection is on and the global ``in_data`` is set, the crop is
      converted to grayscale, viewed as a 1x1x64x64 tensor, classified, and
      the label with the highest running score is drawn on the frame.

    The capture device and all OpenCV windows are released on exit, even if
    an exception interrupts the loop.
    """
    global hand_hist, in_data

    # Init running label
    last_10_detection = np.zeros(10)

    # Init Video Capture
    is_hand_hist_created = False
    capture = cv2.VideoCapture(0)

    # Init Model (map_location keeps every tensor on the CPU)
    model = Net()
    model.load_state_dict(
        torch.load('./model/model_sl_3968.pt',
                   map_location=lambda storage, location: storage))
    model.eval()

    detection_result = 'None'

    try:
        while capture.isOpened():
            pressed_key = cv2.waitKey(1)
            grabbed, frame = capture.read()
            if not grabbed:
                # No frame was delivered; processing a missing frame would
                # crash in the helpers below, so stop cleanly instead.
                break

            # Start / Stop Detection when 'z' pressed
            if pressed_key & 0xFF == ord('z'):
                if is_hand_hist_created:
                    is_hand_hist_created = False
                else:
                    is_hand_hist_created = True
                    hand_hist = hand_histogram(frame)

                    # Reinit running label
                    last_10_detection = np.zeros(10)

            if is_hand_hist_created:
                # presumably this call refreshes the global in_data - confirm
                frame = manage_image_opr(frame, hand_hist)
            else:
                frame = draw_rect(frame)

            # Perform Detection
            if in_data is not None and is_hand_hist_created:
                g_img = cv2.cvtColor(in_data, cv2.COLOR_BGR2GRAY)
                x = torch.FloatTensor(g_img).view(1, 1, 64, 64) / 255
                x = (x - 0.5) / 0.5

                with torch.no_grad():
                    y = model(x)
                    y_idx = F.softmax(y, dim=-1).argmax().numpy()

                # Update likelihood: net +1 for the predicted class, -1 for
                # every other class, clamped to [0, 8].
                last_10_detection[y_idx] += 2
                last_10_detection = last_10_detection - 1
                last_10_detection = np.clip(last_10_detection, 0, 8)

                detection_result = label_dict[int(
                    np.argmax(last_10_detection))]
            else:
                detection_result = 'None'

            # Render to Screen
            if is_hand_hist_created and in_data is not None:
                frame[:75, :180, :] = 0
                frame[:64, -64:, :] = np.expand_dims(
                    cv2.cvtColor(in_data, cv2.COLOR_BGR2GRAY), axis=-1)
                cv2.putText(frame, 'DETECTED', (5, 30),
                            cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), 1,
                            cv2.LINE_AA)
                cv2.putText(frame, '{}'.format(detection_result), (5, 65),
                            cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), 1,
                            cv2.LINE_AA)

            cv2.imshow("FunTorch", rescale_frame(frame))

            # Close if ESC pressed (masked the same way as the 'z' check)
            if pressed_key & 0xFF == 27:
                break
    finally:
        cv2.destroyAllWindows()
        capture.release()