import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

# project-local modules (EmbodiedQA-style layout assumed)
from data import EqaDataLoader, load_vocab
from metrics import NavMetric, VqaMetric
from models import (NavPlannerControllerModel, VqaLstmCnnAttentionModel,
                    get_state)


def eval(rank,
         args,
         shared_nav_model,
         shared_ans_model,
         best_eval_acc=0,
         epoch=0):
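    """Evaluation worker for PACMAN-style EQA.

    Repeatedly syncs weights from the shared nav/ans models, spawns the
    agent 10/30/50 actions from the target, rolls out the navigator,
    answers from the trajectory's last 5 frames, and logs/checkpoints
    the navigation and VQA metrics.
    """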

    # pin this worker to one of the available GPUs, round-robin by rank
    torch.cuda.set_device(args.gpus[rank % len(args.gpus)])

    if args.model_type == 'pacman':
        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        nav_model = NavPlannerControllerModel(**model_kwargs)
    else:
        raise ValueError('unknown model type: %s' % args.model_type)

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

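    # batch_size 1 evaluates one episode at a time; num_frames=5 matches
    # the answerer's input of the trajectory's last five frames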
    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': False
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_eval_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_eval_' + str(rank) + '.json')

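    # each epoch below re-syncs weights from the shared (training) models,
    # runs the full eval split, and checkpoints when accuracy improves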
    t = 0

    while epoch < int(args.max_epochs):
        print('###############################')
        print('[eval] Epoch is:', epoch)
        print('###############################')

        start_time = time.time()
        invalids = []

        nav_model.load_state_dict(shared_nav_model.state_dict())
        nav_model.eval()

        ans_model.load_state_dict(shared_ans_model.state_dict())
        ans_model.eval()
        ans_model.cuda()

        # navigation metrics, logged separately per spawn distance (10/30/50)
        nav_metrics = NavMetric(
            info={
                'split': args.eval_split,
                'thread': rank
            },
            metric_names=[
                'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                'ep_len_30', 'ep_len_50'
            ],
            log_json=args.output_nav_log_path)

        vqa_metrics = VqaMetric(
            info={
                'split': args.eval_split,
                'thread': rank
            },
            metric_names=[
                'accuracy_10', 'accuracy_30', 'accuracy_50', 'mean_rank_10',
                'mean_rank_30', 'mean_rank_50', 'mean_reciprocal_rank_10',
                'mean_reciprocal_rank_30', 'mean_reciprocal_rank_50'
            ],
            log_json=args.output_ans_log_path)

        if 'pacman' in args.model_type:
            print('[eval] evaluating PACMAN model')
            done = False

            while not done:

                for batch in tqdm(eval_loader):

                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch
                    metrics_slug = {}
                    print('question is: ', question)
                    print('answer is: ', answer)
                    h3d = eval_loader.dataset.episode_house

                    # evaluate from spawn points 10, 30, and 50 actions
                    # (along the ground-truth trajectory) from the target
                    for i in [10, 30, 50]:

                        t += 1

                        if i > action_length[0]:
                            invalids.append([idx[0], i])
                            continue

                        question_var = Variable(question.cuda())

                        controller_step = False
                        planner_hidden = nav_model.planner_nav_rnn.init_hidden(1)

                        # replay the ground-truth actions through the planner
                        # up to the spawn point to build up its hidden state
                        (
                            planner_actions_in, planner_img_feats,
                            controller_step, controller_action_in,
                            controller_img_feat, init_pos,
                            controller_action_counter
                        ) = eval_loader.dataset.get_hierarchical_features_till_spawn(
                            actions[0, :action_length[0] + 1].numpy(), i)

                        planner_actions_in_var = Variable(
                            planner_actions_in.cuda())
                        planner_img_feats_var = Variable(
                            planner_img_feats.cuda())

                        for step in range(planner_actions_in.size(0)):

                            planner_scores, planner_hidden = nav_model.planner_step(
                                question_var,
                                planner_img_feats_var[step].view(1, 1, 3200),
                                planner_actions_in_var[step].view(1, 1),
                                planner_hidden)

                        if controller_step:

                            controller_img_feat_var = Variable(
                                controller_img_feat.cuda())
                            controller_action_in_var = Variable(
                                torch.LongTensor(1, 1).fill_(
                                    int(controller_action_in)).cuda())

                            controller_scores = nav_model.controller_step(
                                controller_img_feat_var.view(1, 1, 3200),
                                controller_action_in_var.view(1, 1),
                                planner_hidden[0])

                            prob = F.softmax(controller_scores, dim=1)
                            controller_action = int(
                                prob.max(1)[1].data.cpu().numpy()[0])

                            controller_step = (controller_action == 1)

                            action = int(controller_action_in)
                            action_in = torch.LongTensor(1, 1).fill_(
                                action + 1).cuda()

                        else:

                            prob = F.softmax(planner_scores, dim=1)
                            action = int(prob.max(1)[1].data.cpu().numpy()[0])

                            action_in = torch.LongTensor(1, 1).fill_(
                                action + 1).cuda()

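                        # move the agent to the spawn pose; init_pos is
                        # (x, y, z, yaw), and reset() takes ground-plane
                        # coordinates, hence init_pos[2] for its y argument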
                        h3d.env.reset(x=init_pos[0],
                                      y=init_pos[2],
                                      yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        episode_length = 0
                        episode_done = True
                        controller_action_counter = 0

                        dists_to_target, pos_queue, pred_actions = [
                            init_dist_to_target
                        ], [init_pos], []
                        planner_actions, controller_actions = [], []

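                        # action 3 is STOP; if the first predicted action is
                        # already STOP, skip the rollout and answer from the
                        # spawn location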
                        if action != 3:

                            # take the first step
                            img, _, _ = h3d.step(action)
                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = eval_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224,
                                                  224).cuda())).view(
                                                      1, 1, 3200)

                            for step in range(args.max_episode_length):

                                episode_length += 1

                                if not controller_step:
                                    planner_scores, planner_hidden = nav_model.planner_step(
                                        question_var, img_feat_var,
                                        Variable(action_in), planner_hidden)

                                    prob = F.softmax(planner_scores, dim=1)
                                    action = int(
                                        prob.max(1)[1].data.cpu().numpy()[0])
                                    planner_actions.append(action)

                                pred_actions.append(action)
                                img, _, episode_done = h3d.step(action)

                                episode_done = episode_done or episode_length >= args.max_episode_length

                                img = torch.from_numpy(img.transpose(
                                    2, 0, 1)).float() / 255.0
                                img_feat_var = eval_loader.dataset.cnn(
                                    Variable(img.view(1, 3, 224,
                                                      224).cuda())).view(
                                                          1, 1, 3200)

                                dists_to_target.append(
                                    h3d.get_dist_to_target(h3d.env.cam.pos))
                                pos_queue.append([
                                    h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                    h3d.env.cam.pos.z, h3d.env.cam.yaw
                                ])

                                if episode_done:
                                    break

                                # query controller to continue or not
                                controller_action_in = Variable(
                                    torch.LongTensor(1,
                                                     1).fill_(action).cuda())
                                controller_scores = nav_model.controller_step(
                                    img_feat_var, controller_action_in,
                                    planner_hidden[0])

                                prob = F.softmax(controller_scores, dim=1)
                                controller_action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

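                                # the controller may repeat the current action
                                # at most 4 consecutive times before control
                                # returns to the planner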
                                if controller_action == 1 and controller_action_counter < 4:
                                    controller_action_counter += 1
                                    controller_step = True
                                else:
                                    controller_action_counter = 0
                                    controller_step = False
                                    controller_action = 0

                                controller_actions.append(controller_action)

                                action_in = torch.LongTensor(
                                    1, 1).fill_(action + 1).cuda()

                        # run the answerer on the last 5 frames; if the
                        # trajectory is shorter, pad from the front with
                        # positions from the ground-truth shortest path
                        if len(pos_queue) < 5:
                            pos_queue = eval_loader.dataset.episode_pos_queue[
                                len(pos_queue) - 5:] + pos_queue
                        images = eval_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)

                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)

                        pred_answer = scores.max(1)[1].data[0]

                        # decode question/answer tokens for logging
                        q_vocab = eval_loader.dataset.vocab['questionIdxToToken']
                        a_vocab = eval_loader.dataset.vocab['answerIdxToToken']
                        print('[Q_GT]', ' '.join(
                            q_vocab[x.item()] for x in question[0] if x != 0))
                        print('[A_GT]', a_vocab[answer[0].item()])
                        print('[A_PRED]', a_vocab[pred_answer.item()])

                        # compute stats
                        metrics_slug['accuracy_' + str(i)] = ans_acc[0]
                        metrics_slug['mean_rank_' + str(i)] = ans_rank[0]
                        metrics_slug['mean_reciprocal_rank_' +
                                     str(i)] = 1.0 / ans_rank[0]

                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug['d_D_' + str(i)] = (
                            dists_to_target[0] - dists_to_target[-1])
                        metrics_slug['d_min_' + str(i)] = min(dists_to_target)
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        metrics_slug['stop_' + str(i)] = 1 if action == 3 else 0

                        inside_room = [
                            h3d.is_inside_room(p,
                                               eval_loader.dataset.target_room)
                            for p in pos_queue
                        ]
                        metrics_slug['r_T_' + str(i)] = 1 if inside_room[-1] else 0
                        metrics_slug['r_e_' + str(i)] = 1 if any(inside_room) else 0

                    # navigation metrics: distances skipped above (invalid
                    # spawns) fall back to the metric's previous value so
                    # update() always receives a full row
                    metrics_list = []
                    for i in nav_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(nav_metrics.metrics[
                                nav_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    nav_metrics.update(metrics_list)

                    # vqa metrics
                    metrics_list = []
                    for i in vqa_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(vqa_metrics.metrics[
                                vqa_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    vqa_metrics.update(metrics_list)

                try:
                    print(nav_metrics.get_stat_string(mode=0))
                    print(vqa_metrics.get_stat_string(mode=0))
                except Exception:
                    # stat strings are unavailable until the first update
                    pass

                print('epoch', epoch)
                print('invalids', len(invalids))

                # page in the next chunk of environments; done once the
                # pruned set is exhausted
                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        epoch += 1

        # checkpoint on the logging schedule
        if epoch % args.eval_every == 0 and args.to_log == 1:
            vqa_metrics.dump_log()
            nav_metrics.dump_log()

            model_state = get_state(nav_model)

            # drop private attributes from args before serializing
            ad = {
                k: v
                for k, v in args.__dict__.items() if not k.startswith('_')
            }

            checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

            checkpoint_path = '%s/epoch_%d_ans_50_%.04f.pt' % (
                args.checkpoint_dir, epoch, best_eval_acc)
            print('Saving checkpoint to %s' % checkpoint_path)
            torch.save(checkpoint, checkpoint_path)
        # additionally checkpoint when the 50-step answer accuracy improves
        if vqa_metrics.metrics[2][0] > best_eval_acc:  # ans_acc_50
            best_eval_acc = vqa_metrics.metrics[2][0]
            if epoch % args.eval_every == 0 and args.to_log == 1:
                vqa_metrics.dump_log()
                nav_metrics.dump_log()

                model_state = get_state(nav_model)

                # drop private attributes from args before serializing
                ad = {
                    k: v
                    for k, v in args.__dict__.items() if not k.startswith('_')
                }

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_ans_50_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_ans_acc_50:%.04f]' % best_eval_acc)

        # rewind the environment list so the next epoch starts at the top
        eval_loader.dataset._load_envs(start_idx=0, in_order=True)
    return best_eval_acc
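

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal illustration of how
# an eval worker like this is typically launched with torch.multiprocessing,
# sharing model weights with a trainer process. The flag registration below
# is a hypothetical stand-in for the project's full argparse setup.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    import torch.multiprocessing as mp

    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_json', default='data/vocab.json')
    # ... the project's remaining flags (gpus, log_dir, etc.) go here ...
    train_args = parser.parse_args()

    # build the shared models and expose their weights to child processes
    shared_nav = NavPlannerControllerModel(
        question_vocab=load_vocab(train_args.vocab_json))
    shared_ans = VqaLstmCnnAttentionModel(
        vocab=load_vocab(train_args.vocab_json))
    shared_nav.share_memory()
    shared_ans.share_memory()

    mp.set_start_method('spawn', force=True)  # required for CUDA in workers
    p = mp.Process(target=eval, args=(0, train_args, shared_nav, shared_ans))
    p.start()
    p.join()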