Exemple #1
0
def train(rank, args, shared_model):
    """Run the VQA training loop of one worker against ``shared_model``.

    A local model is built according to ``args.input_type``.  Before every
    batch the local model is reloaded from ``shared_model.state_dict()``;
    after the backward pass its gradients are handed to ``shared_model``
    through ``ensure_shared_grads`` and the Adam optimizer (built over
    ``shared_model.parameters()``) takes a step.  Whenever
    ``epoch % args.eval_every == 0`` the module-level ``eval`` is called
    with the local model.

    Args:
        rank: Worker index.  Selects the entry of ``args.gpus`` to use and
            is embedded in the log file names.
        args: Options namespace.  Fields read here: ``gpus``,
            ``input_type``, ``vocab_json``, ``learning_rate``, ``train_h5``,
            ``val_h5``, ``data_json``, ``batch_size``, ``num_frames``,
            ``max_threads_per_gpu``, ``to_cache``, ``log_dir``,
            ``max_epochs``, ``print_every``, ``to_log`` and ``eval_every``.
            ``args.output_log_path`` is overwritten as a side effect.
        shared_model: Model whose parameters are optimized.

    Raises:
        ValueError: If ``args.input_type`` is neither ``'ques'`` nor
            ``'ques,image'``.
    """
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    else:
        # Fail early and clearly instead of hitting an unbound ``model``
        # at the first evaluation.
        raise ValueError('unsupported input_type: %r' % (args.input_type, ))

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer updates the *shared* parameters; the local model only
    # provides the gradients.
    optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    shared_model.parameters()),
                             lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    eval_loader_kwargs = {
        'questions_h5': args.val_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'val',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')
    eval_output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')
    metrics = VqaMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    # NOTE(review): eval_metrics and eval_loader are not used further down
    # (eval() is called with its own arguments).  They are still constructed
    # because their constructors may have side effects -- confirm and drop.
    eval_metrics = VqaMetric(
        info={'split': 'eval'},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=eval_output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    eval_loader = EqaDataLoader(**eval_loader_kwargs)

    if args.input_type == 'ques,image':
        train_loader.dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(train_loader.dataset))
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    # t counts processed batches, epoch counts passes over the data.
    t, epoch = 0, 0
    best_eval_acc = 0

    while epoch < int(args.max_epochs):

        if args.input_type == 'ques':

            for batch in train_loader:

                t += 1

                # Start every step from the current shared weights.
                model.load_state_dict(shared_model.state_dict())
                model.train()
                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # zero grad
                optim.zero_grad()

                # update metrics (loss.item() matches the 'ques,image' branch;
                # loss.data[0] fails on 0-dim tensors in current PyTorch)
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

                # backprop and update
                loss.backward()

                # model.cpu() moves the local model back to the CPU before
                # its gradients are passed to the shared model.
                ensure_shared_grads(model.cpu(), shared_model)
                optim.step()

                if t % args.print_every == 0:
                    print(metrics.get_stat_string())
                    if args.to_log == 1:
                        metrics.dump_log()

        elif args.input_type == 'ques,image':

            model.train()
            model.cuda()
            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            # One pass of the inner for-loop covers the currently loaded
            # environments; the while-loop keeps loading further ones until
            # the dataset reports none are left.
            while not done:
                for batch in train_loader:

                    # Count batches (as the 'ques' branch does) so that
                    # print_every is expressed in batches here as well.
                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch
                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        if args.to_log == 1:
                            metrics.dump_log()

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        # Without caching, go back to the first environments
                        # so the next epoch starts from index 0 again.
                        if args.to_cache == False:
                            train_loader.dataset._load_envs(start_idx=0,
                                                            in_order=True)
                else:
                    done = True

        if epoch % args.eval_every == 0:

            best_eval_acc = eval(0,
                                 args,
                                 model,
                                 best_eval_acc=best_eval_acc,
                                 epoch=epoch)

        epoch += 1
Exemple #2
0
def eval(rank, args, shared_model, best_eval_acc=0, checkpoint=None, epoch=0):
    """Evaluate ``shared_model`` once on ``args.eval_split``.

    A fresh local model is built for ``args.input_type``, loaded from
    ``shared_model.state_dict()`` and run over an ``EqaDataLoader`` with
    batch size 1.  In the ``'ques,image'`` case a checkpoint is written to
    ``args.checkpoint_dir`` when the accuracy is at least ``best_eval_acc``;
    the ``'ques'`` case only prints the statistics.

    Args:
        rank: Worker index; selects the GPU from ``args.gpus`` and names the
            log file.
        args: Options namespace.  ``args.output_log_path`` is overwritten.
        shared_model: Model whose weights are evaluated.
        best_eval_acc: Best accuracy seen by the caller so far.
        checkpoint: Optional previously loaded checkpoint dict.  When given
            and ``args.checkpoint_path != False``, its ``'args'`` entry is
            stored in the new checkpoint instead of ``args.__dict__``.
        epoch: Epoch number, used in messages and the checkpoint file name.

    Returns:
        ``best_eval_acc``, replaced by the new accuracy when a checkpoint
        was written.

    Raises:
        ValueError: If ``args.input_type`` is neither ``'ques'`` nor
            ``'ques,image'``.
    """
    print('Evaluating at {} epoch, with Acc {} ##################'.format(
        epoch, best_eval_acc))

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    else:
        # Previously this surfaced later as an unbound ``model``.
        raise ValueError('unsupported input_type: %r' % (args.input_type, ))

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    # Ranks collected for the final "Mean Rank" report.
    mean_rank = []

    model.load_state_dict(shared_model.state_dict())
    model.eval()
    # Moving the model once is enough; the weights do not change below.
    model.cuda()

    metrics = VqaMetric(
        info={'split': args.eval_split},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    if args.input_type == 'ques':
        for batch in eval_loader:
            idx, questions, answers = batch

            questions_var = Variable(questions.cuda())
            answers_var = Variable(answers.cuda())

            scores = model(questions_var)
            loss = lossFn(scores, answers_var)

            # update metrics (loss.item() matches the 'ques,image' branch;
            # loss.data[0] fails on 0-dim tensors in current PyTorch)
            accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
            # Collect ranks here too, so the final report has data instead
            # of dividing by zero.
            mean_rank.extend(ranks)
            metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

        print(metrics.get_stat_string(mode=0))

    elif args.input_type == 'ques,image':
        done = False
        all_envs_loaded = eval_loader.dataset._check_if_all_envs_loaded()

        # Each while-iteration evaluates the environments currently loaded
        # and then asks the dataset for the next ones.
        while not done:
            mean_rank = []
            for batch in eval_loader:
                idx, questions, answers, images, _, _, _ = batch
                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())
                images_var = Variable(images.cuda())

                scores, att_probs = model(images_var, questions_var)
                loss = lossFn(scores, answers_var)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                mean_rank.extend(ranks)
                print("Batch Mean Ranks", sum(ranks) / len(ranks))

                metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

            print(metrics.get_stat_string(mode=0))
            # A chunk may yield no batches; skip the mean in that case.
            if mean_rank:
                print("Mean Rank for eval", sum(mean_rank) / len(mean_rank))
            if all_envs_loaded == False:
                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True
            else:
                done = True

        # checkpoint if best val accuracy
        if metrics.metrics[1][0] >= best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]

            metrics.dump_log()

            model_state = get_state(model)

            if args.checkpoint_path != False and checkpoint is not None:
                ad = checkpoint['args']
            else:
                ad = args.__dict__

            checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

            checkpoint_path = '%s/epoch_%d_accuracy_%d.pt' % (
                args.checkpoint_dir, epoch, int(best_eval_acc * 100))

            print('Saving checkpoint to %s' % checkpoint_path)
            torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)

    # Guarded: an empty list used to raise ZeroDivisionError here, which
    # prevented the function from returning at all.
    if mean_rank:
        print("Mean Rank for eval", sum(mean_rank) / len(mean_rank))
    return best_eval_acc
Exemple #3
0
def eval(rank, args, shared_model, checkpoint=None):
    """Evaluate ``shared_model`` repeatedly, once per epoch.

    For ``args.max_epochs`` iterations the local model is reloaded from
    ``shared_model.state_dict()`` and run over ``args.eval_split``.  When
    the accuracy improves on the best value so far, and
    ``epoch % args.eval_every == 0`` and ``args.log == True``, the metrics
    are dumped and a checkpoint is saved under ``args.checkpoint_dir``.

    Args:
        rank: Worker index; selects the GPU from ``args.gpus`` and names the
            log file.
        args: Options namespace.  ``args.output_log_path`` is overwritten.
        shared_model: Model whose weights are evaluated.
        checkpoint: Optional previously loaded checkpoint dict.  When given
            and ``args.checkpoint_path != False``, its ``'args'`` entry is
            carried over into newly saved checkpoints.  The original code
            read an unassigned local ``checkpoint`` here and raised
            UnboundLocalError on the first save.

    Raises:
        ValueError: If ``args.input_type`` is neither ``'ques'`` nor
            ``'ques,image'``.
    """
    print('eval start...')

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    else:
        # Previously this surfaced later as an unbound ``model``.
        raise ValueError('unsupported input_type: %r' % (args.input_type, ))

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    epoch, best_eval_acc = 0, 0

    def _loss_value(loss):
        # Tensor.item() exists from PyTorch 0.4 on; older releases need
        # loss.data[0], which in turn fails on 0-dim tensors in newer ones.
        return loss.item() if hasattr(loss, 'item') else loss.data[0]

    while epoch < int(args.max_epochs):

        # Pick up whatever the trainers have written to the shared model.
        model.load_state_dict(shared_model.state_dict())
        model.eval()
        model.cuda()

        # Fresh metrics every epoch so accuracies are per-epoch values.
        metrics = VqaMetric(info={'split': args.eval_split},
                            metric_names=[
                                'loss', 'accuracy', 'mean_rank',
                                'mean_reciprocal_rank'
                            ],
                            log_json=args.output_log_path)

        if args.input_type == 'ques':
            for batch in eval_loader:
                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                print(scores)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update(
                    [_loss_value(loss), accuracy, ranks, 1.0 / ranks])

            print(metrics.get_stat_string(mode=0))

        elif args.input_type == 'ques,image':
            done = False
            # NOTE(review): the dataset query
            # eval_loader.dataset._check_if_all_envs_loaded() is disabled
            # and the flag is hard-coded, so the loop below runs exactly
            # one pass and never calls _load_envs().
            all_envs_loaded = True

            while not done:
                for batch in eval_loader:
                    idx, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update(
                        [_loss_value(loss), accuracy, ranks, 1.0 / ranks])

                print(metrics.get_stat_string(mode=0))

                if all_envs_loaded == False:
                    eval_loader.dataset._load_envs()
                    if len(eval_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy
        if metrics.metrics[1][0] > best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]
            if epoch % args.eval_every == 0 and args.log == True:
                metrics.dump_log()

                model_state = get_state(model)

                # Reuse the args of the checkpoint we started from when one
                # was supplied; otherwise store the current options.
                if args.checkpoint_path != False and checkpoint is not None:
                    ad = checkpoint['args']
                else:
                    ad = args.__dict__

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)
def eval(rank, args, shared_nav_model, shared_ans_model):
    """Evaluate a navigator and an answerer together, once per epoch.

    For every question in ``args.eval_split`` the navigator is run from
    three initializations (``i`` in 10, 30, 50): the recorded actions are
    replayed through the planner up to the spawn point, the agent is then
    stepped through the environment with the planner/controller, and the
    answerer is run on the last five positions visited.  Navigation and
    answering metrics are accumulated separately.  When the ``accuracy_50``
    metric improves, and ``epoch % args.eval_every == 0`` and
    ``args.to_log == 1``, both logs are dumped and the navigator state is
    checkpointed under ``args.checkpoint_dir``.

    Args:
        rank: Worker index; selects the GPU from ``args.gpus`` and names the
            log files.
        args: Options namespace.  ``args.output_nav_log_path`` and
            ``args.output_ans_log_path`` are overwritten.
        shared_nav_model: Navigator whose weights are copied into the local
            navigator before every batch.
        shared_ans_model: Answerer whose weights are copied into the local
            answerer at the start of every epoch.
    """
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        nav_model = NavPlannerControllerModel(**model_kwargs)

    else:

        # Only the 'pacman' navigator is handled; anything else ends the
        # worker (exit() raises SystemExit).
        exit()

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': False
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_eval_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0.0

    while epoch < int(args.max_epochs):

        start_time = time.time()
        # [question index, i] pairs that were skipped in this epoch.
        invalids = []

        nav_model.load_state_dict(shared_nav_model.state_dict())
        nav_model.eval()

        ans_model.load_state_dict(shared_ans_model.state_dict())
        ans_model.eval()
        ans_model.cuda()

        # One metric per quantity and per initialization (10 / 30 / 50).
        nav_metrics = NavMetric(
            info={
                'split': args.eval_split,
                'thread': rank
            },
            metric_names=[
                'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                'ep_len_30', 'ep_len_50'
            ],
            log_json=args.output_nav_log_path)

        vqa_metrics = VqaMetric(
            info={
                'split': args.eval_split,
                'thread': rank
            },
            metric_names=[
                'accuracy_10', 'accuracy_30', 'accuracy_50', 'mean_rank_10',
                'mean_rank_30', 'mean_rank_50', 'mean_reciprocal_rank_10',
                'mean_reciprocal_rank_30', 'mean_reciprocal_rank_50'
            ],
            log_json=args.output_ans_log_path)

        if 'pacman' in args.model_type:

            done = False

            # Each while-iteration covers the environments currently loaded;
            # the loop ends when the dataset has no further ones to load.
            while done == False:

                for batch in tqdm(eval_loader):

                    # Refresh the navigator from the shared weights for
                    # every question.
                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch
                    # Metric name -> value for this question; filled per i.
                    metrics_slug = {}

                    h3d = eval_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    # (i is handed to get_hierarchical_features_till_spawn;
                    # presumably the number of steps before the end of the
                    # recorded path at which the agent is spawned -- confirm)
                    for i in [10, 30, 50]:

                        t += 1

                        # Recorded path is shorter than i: skip.
                        if i > action_length[0]:
                            invalids.append([idx[0], i])
                            continue

                        question_var = Variable(question.cuda())

                        controller_step = False
                        planner_hidden = nav_model.planner_nav_rnn.init_hidden(
                            1)

                        # forward through planner till spawn
                        (planner_actions_in, planner_img_feats,
                         controller_step, controller_action_in,
                         controller_img_feat, init_pos
                         ) = eval_loader.dataset.get_hierarchical_features_till_spawn(
                             actions[0, :action_length[0] + 1].numpy(), i)

                        planner_actions_in_var = Variable(
                            planner_actions_in.cuda())
                        planner_img_feats_var = Variable(
                            planner_img_feats.cuda())

                        # Replay the recorded planner inputs to bring the
                        # planner's hidden state up to the spawn point.
                        for step in range(planner_actions_in.size(0)):

                            planner_scores, planner_hidden = nav_model.planner_step(
                                question_var,
                                planner_img_feats_var[step].view(1, 1, 3200),
                                planner_actions_in_var[step].view(1, 1),
                                planner_hidden)

                        if controller_step == True:

                            # The spawn point falls on a controller step:
                            # ask the controller whether to keep repeating
                            # the recorded action.
                            controller_img_feat_var = Variable(
                                controller_img_feat.cuda())
                            controller_action_in_var = Variable(
                                torch.LongTensor(1, 1).fill_(
                                    int(controller_action_in)).cuda())

                            controller_scores = nav_model.controller_step(
                                controller_img_feat_var.view(1, 1, 3200),
                                controller_action_in_var.view(1, 1),
                                planner_hidden[0])

                            prob = F.softmax(controller_scores, dim=1)
                            controller_action = int(
                                prob.max(1)[1].data.cpu().numpy()[0])

                            if controller_action == 1:
                                controller_step = True
                            else:
                                controller_step = False

                            action = int(controller_action_in)
                            action_in = torch.LongTensor(1, 1).fill_(action +
                                                                     1).cuda()

                        else:

                            # Otherwise the first action is the planner's
                            # most likely one.
                            prob = F.softmax(planner_scores, dim=1)
                            action = int(prob.max(1)[1].data.cpu().numpy()[0])

                            action_in = torch.LongTensor(1, 1).fill_(action +
                                                                     1).cuda()

                        h3d.env.reset(x=init_pos[0],
                                      y=init_pos[2],
                                      yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        episode_length = 0
                        episode_done = True
                        # Consecutive times the controller chose to repeat.
                        controller_action_counter = 0

                        dists_to_target, pos_queue, pred_actions = [
                            init_dist_to_target
                        ], [init_pos], []
                        planner_actions, controller_actions = [], []

                        # Action 3 is recorded as "stop" further down; the
                        # agent only moves if the first action is not stop.
                        if action != 3:

                            # take the first step
                            img, _, _ = h3d.step(action)
                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = eval_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224,
                                                  224).cuda())).view(
                                                      1, 1, 3200)

                            for step in range(args.max_episode_length):

                                episode_length += 1

                                # The planner picks a new action only when
                                # the controller released control; otherwise
                                # the previous action is repeated.
                                if controller_step == False:
                                    planner_scores, planner_hidden = nav_model.planner_step(
                                        question_var, img_feat_var,
                                        Variable(action_in), planner_hidden)

                                    prob = F.softmax(planner_scores, dim=1)
                                    action = int(
                                        prob.max(1)[1].data.cpu().numpy()[0])
                                    planner_actions.append(action)

                                pred_actions.append(action)
                                img, _, episode_done = h3d.step(action)

                                episode_done = episode_done or episode_length >= args.max_episode_length

                                img = torch.from_numpy(img.transpose(
                                    2, 0, 1)).float() / 255.0
                                img_feat_var = eval_loader.dataset.cnn(
                                    Variable(img.view(1, 3, 224,
                                                      224).cuda())).view(
                                                          1, 1, 3200)

                                dists_to_target.append(
                                    h3d.get_dist_to_target(h3d.env.cam.pos))
                                pos_queue.append([
                                    h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                    h3d.env.cam.pos.z, h3d.env.cam.yaw
                                ])

                                if episode_done == True:
                                    break

                                # query controller to continue or not
                                controller_action_in = Variable(
                                    torch.LongTensor(1,
                                                     1).fill_(action).cuda())
                                controller_scores = nav_model.controller_step(
                                    img_feat_var, controller_action_in,
                                    planner_hidden[0])

                                prob = F.softmax(controller_scores, dim=1)
                                controller_action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

                                # The controller may repeat an action at
                                # most 4 times in a row before control goes
                                # back to the planner.
                                if controller_action == 1 and controller_action_counter < 4:
                                    controller_action_counter += 1
                                    controller_step = True
                                else:
                                    controller_action_counter = 0
                                    controller_step = False
                                    controller_action = 0

                                controller_actions.append(controller_action)

                                action_in = torch.LongTensor(
                                    1, 1).fill_(action + 1).cuda()

                        # run answerer here
                        # The answerer needs five positions; if fewer were
                        # visited, pad at the front with the tail of the
                        # dataset's recorded position queue.
                        if len(pos_queue) < 5:
                            pos_queue = eval_loader.dataset.episode_pos_queue[
                                len(pos_queue) - 5:] + pos_queue
                        images = eval_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)
                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)

                        pred_answer = scores.max(1)[1].data[0]

                        print(
                            '[Q_GT]', ' '.join([
                                eval_loader.dataset.vocab['questionIdxToToken']
                                [x] for x in question[0] if x != 0
                            ]))
                        print(
                            '[A_GT]',
                            eval_loader.dataset.vocab['answerIdxToToken'][
                                answer[0]])
                        print(
                            '[A_PRED]',
                            eval_loader.dataset.vocab['answerIdxToToken']
                            [pred_answer])

                        # compute stats
                        metrics_slug['accuracy_' + str(i)] = ans_acc[0]
                        metrics_slug['mean_rank_' + str(i)] = ans_rank[0]
                        metrics_slug['mean_reciprocal_rank_' +
                                     str(i)] = 1.0 / ans_rank[0]

                        # d_0: initial distance, d_T: final distance,
                        # d_D: distance gained, d_min: closest approach.
                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug[
                            'd_D_' +
                            str(i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' +
                                     str(i)] = np.array(dists_to_target).min()
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        if action == 3:
                            metrics_slug['stop_' + str(i)] = 1
                        else:
                            metrics_slug['stop_' + str(i)] = 0
                        inside_room = []
                        for p in pos_queue:
                            inside_room.append(
                                h3d.is_inside_room(
                                    p, eval_loader.dataset.target_room))
                        # r_T: final position is inside the target room;
                        # r_e: any visited position was inside it.
                        if inside_room[-1] == True:
                            metrics_slug['r_T_' + str(i)] = 1
                        else:
                            metrics_slug['r_T_' + str(i)] = 0
                        if any([x == True for x in inside_room]) == True:
                            metrics_slug['r_e_' + str(i)] = 1
                        else:
                            metrics_slug['r_e_' + str(i)] = 0

                    # navigation metrics
                    # For metrics that were skipped for this question, the
                    # metric's current first stored value is fed back in.
                    metrics_list = []
                    for i in nav_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(nav_metrics.metrics[
                                nav_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    nav_metrics.update(metrics_list)

                    # vqa metrics
                    metrics_list = []
                    for i in vqa_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(vqa_metrics.metrics[
                                vqa_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    vqa_metrics.update(metrics_list)

                # Printing the statistics is best-effort, but only ordinary
                # errors are tolerated: a bare ``except`` would also swallow
                # KeyboardInterrupt and SystemExit.
                try:
                    print(nav_metrics.get_stat_string(mode=0))
                    print(vqa_metrics.get_stat_string(mode=0))
                except Exception as err:
                    print('could not print eval statistics:', err)

                print('epoch', epoch)
                print('invalids', len(invalids))

                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy
        if vqa_metrics.metrics[2][0] > best_eval_acc:  # ans_acc_50
            best_eval_acc = vqa_metrics.metrics[2][0]
            if epoch % args.eval_every == 0 and args.to_log == 1:
                vqa_metrics.dump_log()
                nav_metrics.dump_log()

                model_state = get_state(nav_model)

                # Store the options without entries whose name starts with
                # an underscore.
                aad = dict(args.__dict__)
                ad = {}
                for i in aad:
                    if i[0] != '_':
                        ad[i] = aad[i]

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_ans_50_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_ans_acc_50:%.04f]' % best_eval_acc)

        # Start the next epoch from the first environments again.
        eval_loader.dataset._load_envs(start_idx=0, in_order=True)
Exemple #5
0
def eval(rank, args, shared_model, use_vision, use_language):
    """Evaluate snapshots of ``shared_model`` on ``args.eval_split``.

    Loops until ``args.max_epochs`` is reached. Every iteration copies the
    current weights of ``shared_model`` into a local model, scores each batch
    of the evaluation loader, prints the aggregated metrics, and then blocks
    until ``<args.identifier>.shared_epoch.tmp`` contains an epoch number
    greater than the local one. When the accuracy metric improves, and the
    epoch is a multiple of ``args.eval_every`` with ``args.to_log == 1``, the
    metrics log is dumped and a checkpoint is written to
    ``args.checkpoint_dir``.

    Args:
        rank: Worker index; selects the GPU and names the JSON log file.
        args: Run configuration. ``args.output_log_path`` is overwritten here.
        shared_model: Model whose ``state_dict`` is evaluated.
        use_vision: If falsy, image inputs are replaced with zeros.
        use_language: If falsy, every question is replaced by
            ``<START> <END>`` followed by ``<NULL>`` padding.
    """
    gpu_idx = args.gpus.index(args.gpus[rank % len(args.gpus)])
    torch.cuda.set_device(gpu_idx)
    print("eval gpu:" + str(gpu_idx) + " assigned")

    # NOTE(review): any other input_type leaves `model` undefined.
    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    eval_loader_kwargs = {
        # The h5 path is looked up by split name, e.g. args.val_h5 for 'val'.
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0

    print(epoch, args.max_epochs)  # DEBUG
    while epoch < int(args.max_epochs):
        print("eval gpu:" + str(gpu_idx) + " running epoch " + str(epoch))

        # Evaluate whatever the training workers have produced so far.
        model.load_state_dict(shared_model.state_dict())
        model.eval()

        # Fresh accumulator for every evaluation pass.
        metrics = VqaMetric(info={'split': args.eval_split},
                            metric_names=[
                                'loss', 'accuracy', 'mean_rank',
                                'mean_reciprocal_rank'
                            ],
                            log_json=args.output_log_path)

        if args.input_type == 'ques':
            for batch in eval_loader:
                t += 1

                model.cuda()

                idx, questions, answers = batch

                # If not using language, replace each question with a start
                # and end token back to back.
                if not use_language:
                    token_ids = model_kwargs['vocab']['questionTokenToIdx']
                    questions = torch.zeros_like(questions)
                    questions.fill_(token_ids['<NULL>'])
                    questions[:, 0] = token_ids['<START>']
                    questions[:, 1] = token_ids['<END>']

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

            print(metrics.get_stat_string(mode=0))

        elif args.input_type == 'ques,image':
            done = False
            all_envs_loaded = eval_loader.dataset._check_if_all_envs_loaded()

            # One pass over the loader per set of loaded environments; keep
            # loading further environments until none are left.
            while not done:
                for batch in eval_loader:
                    t += 1

                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    # If not using language, replace each question with a
                    # start and end token back to back.
                    if not use_language:
                        token_ids = model_kwargs['vocab']['questionTokenToIdx']
                        questions = torch.zeros_like(questions)
                        questions.fill_(token_ids['<NULL>'])
                        questions[:, 0] = token_ids['<START>']
                        questions[:, 1] = token_ids['<END>']
                    # If not using vision, replace all image feature data
                    # with zeros.
                    if not use_vision:
                        images = torch.zeros_like(images)

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update(
                        [loss.data[0], accuracy, ranks, 1.0 / ranks])

                print(metrics.get_stat_string(mode=0))

                if all_envs_loaded == False:
                    eval_loader.dataset._load_envs()
                    if len(eval_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        # Wait until the shared epoch file reports an epoch beyond ours.
        # The sleep now happens on every unsuccessful poll: previously it only
        # ran while the file was unreadable, so a readable-but-stale value
        # made this loop spin at full speed re-opening the file.
        # TODO: since merger, this no longer works (hanging); might need to undo changes re: threading that we
        # TODO: made or debug them.
        # NOTE(review): if the writer only ever stores the index of the last
        # finished epoch (at most max_epochs - 1), the final wait can never be
        # satisfied -- confirm against the training side.
        shared_epoch_path = args.identifier + '.shared_epoch.tmp'
        read_epoch = None
        while True:
            try:
                with open(shared_epoch_path, 'r') as f:
                    read_epoch = int(f.read().strip())
            except (IOError, ValueError):
                # File not created yet, or read while being rewritten.
                pass
            if read_epoch is not None and epoch < read_epoch:
                break
            print("eval gpu:" + str(gpu_idx) +
                  " waiting for train thread to finish epoch " +
                  str(epoch))
            # sleep until the training thread finishes another iteration
            time.sleep(10)
        epoch = read_epoch

        # checkpoint if best val accuracy
        # metrics.metrics[1] corresponds to 'accuracy' in metric_names;
        # element [0] is presumably its running value -- TODO confirm.
        if metrics.metrics[1][0] > best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]
            if epoch % args.eval_every == 0 and args.to_log == 1:
                metrics.dump_log()

                model_state = get_state(model)

                # Save the public attributes of args. The previous code read
                # checkpoint['args'] before the local `checkpoint` below had
                # been assigned, which raised UnboundLocalError whenever
                # args.checkpoint_path was set.
                ad = {
                    key: value
                    for key, value in dict(args.__dict__).items()
                    if not key.startswith('_')
                }

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)
Exemple #6
0
def train(rank, args, shared_model):
    """Train a VQA model, pushing gradients into ``shared_model``.

    For every batch the local model is refreshed from ``shared_model``, run
    forward on the GPU, back-propagated, and its gradients are handed to
    ``shared_model`` through ``ensure_shared_grads`` before the optimizer
    (which owns ``shared_model``'s parameters) takes a step.

    Args:
        rank: Worker index; selects the GPU and names the JSON log file.
        args: Run configuration. ``args.output_log_path`` is overwritten here.
        shared_model: Model whose parameters are optimized.
    """
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # NOTE(review): any other input_type leaves `model` undefined.
    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer updates the shared parameters, not the local copy.
    optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    shared_model.parameters()),
                             lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    metrics = VqaMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)
    print('[TRAIN_LOADER] start')
    train_loader = EqaDataLoader(**train_loader_kwargs)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch = 0, 0

    while epoch < int(args.max_epochs):
        # print('epoch no. %d' % epoch)

        if args.input_type == 'ques':

            for batch in train_loader:

                t += 1

                # Start every step from the latest shared weights.
                model.load_state_dict(shared_model.state_dict())
                model.train()
                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # zero grad
                optim.zero_grad()

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                # Use .item() like the 'ques,image' branch below; indexing a
                # scalar loss with [0] is rejected by the PyTorch versions
                # that provide .item().
                metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

                # backprop and update
                loss.backward()

                # Gradients are transferred from the (CPU) local model to the
                # shared model, whose parameters the optimizer then updates.
                ensure_shared_grads(model.cpu(), shared_model)
                optim.step()

                if t % args.print_every == 0:
                    print(metrics.get_stat_string())
                    if args.log == True:
                        metrics.dump_log()

        elif args.input_type == 'ques,image':

            done = False
            # Hard-coded: the environment check is bypassed here, so the
            # loader is iterated exactly once per epoch.
            all_envs_loaded = True
            #all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while not done:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    # The CNN sub-module is kept in eval mode during training.
                    model.cnn.eval()
                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    print('--- images dim {}'.format(images.size()))

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    print('--- att_probs: {}.'.format(att_probs))
                    # print('--- answers_var: {}.'.format(answers_var))
                    # print('--- loss: {}.'.format(loss))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        if args.log == True:
                            metrics.dump_log()

                if all_envs_loaded == False:
                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1
Exemple #7
0
def train(rank, args, shared_model):
    """Train a VQA model against ``shared_model`` and validate every epoch.

    Each training step refreshes the local model from ``shared_model``, runs
    it on the GPU, and passes its gradients to ``shared_model`` through
    ``ensure_shared_grads`` before the optimizer steps. In ``'ques,image'``
    mode a pass over the evaluation loader follows each epoch of training.
    When the accuracy metric improves, and the epoch is a multiple of
    ``args.eval_every`` with ``args.to_log == 1``, a checkpoint is written to
    ``args.checkpoint_dir``.

    Args:
        rank: Worker index; selects the GPU and names the JSON log file.
        args: Run configuration. ``args.output_log_path`` is overwritten here.
        shared_model: Model whose parameters are optimized.
    """
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # NOTE(review): any other input_type leaves `model` undefined.
    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer updates the shared parameters, not the local copy.
    optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    shared_model.parameters()),
                             lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    eval_loader_kwargs = {
        # The h5 path is looked up by split name, e.g. args.val_h5 for 'val'.
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    # NOTE(review): this single accumulator is only fed by the validation
    # pass in 'ques,image' mode (the training-side updates there are
    # commented out) but by training batches in 'ques' mode.
    metrics = VqaMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    train_loader = EqaDataLoader(**train_loader_kwargs)
    if args.input_type == 'ques,image':
        train_loader.dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch, best_eval_acc = 0, 0, 0

    while epoch < int(args.max_epochs):

        if args.input_type == 'ques':

            for batch in train_loader:

                t += 1

                # Start every step from the latest shared weights.
                model.load_state_dict(shared_model.state_dict())
                model.train()
                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # zero grad
                optim.zero_grad()

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

                # backprop and update
                loss.backward()

                ensure_shared_grads(model.cpu(), shared_model)
                optim.step()

                if t % args.print_every == 0:
                    print(metrics.get_stat_string())
                    if args.to_log == 1:
                        metrics.dump_log()

        elif args.input_type == 'ques,image':

            # ---- training pass ----
            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while not done:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    # accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
                    # metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    #if t % args.print_every == 0:
                    #    print(metrics.get_stat_string())
                    #    if args.to_log == 1:
                    #        metrics.dump_log()

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

            # ---- validation pass ----
            # NOTE(review): the model is still in train() mode here and no
            # fresh copy of the shared weights is loaded -- confirm intended.
            env_done = False
            env_all_envs_loaded = eval_loader.dataset._check_if_all_envs_loaded(
            )

            while not env_done:
                for batch in eval_loader:
                    t += 1

                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    # Record every batch. The update used to sit after the
                    # loop, so only the last batch of each pass was counted
                    # (and an empty loader hit an undefined `loss`).
                    metrics.update(
                        [loss.data[0], accuracy, ranks, 1.0 / ranks])

                print(metrics.get_stat_string(mode=0))

                if env_all_envs_loaded == False:
                    eval_loader.dataset._load_envs()
                    if len(eval_loader.dataset.pruned_env_set) == 0:
                        env_done = True
                else:
                    env_done = True

        epoch += 1

        # checkpoint if best val accuracy
        # metrics.metrics[1] corresponds to 'accuracy' in metric_names;
        # element [0] is presumably its running value -- TODO confirm.
        if metrics.metrics[1][0] > best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]
            if epoch % args.eval_every == 0 and args.to_log == 1:
                metrics.dump_log()

                model_state = get_state(model)

                # Save the public attributes of args. The previous code read
                # checkpoint['args'] before the local `checkpoint` below had
                # been assigned, which raised UnboundLocalError whenever
                # args.checkpoint_path was set.
                ad = {
                    key: value
                    for key, value in dict(args.__dict__).items()
                    if not key.startswith('_')
                }

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)
Exemple #8
0
def train(rank, args, shared_model, use_vision, use_language):
    """Training worker with optional vision / language ablation.

    Every step reloads the local model from ``shared_model``, runs a forward
    and backward pass on the GPU, forwards the gradients to ``shared_model``
    via ``ensure_shared_grads`` and steps an Adam optimizer that owns
    ``shared_model``'s trainable parameters. After each epoch the epoch index
    is written to ``<args.identifier>.shared_epoch.tmp``.

    Args:
        rank: Worker index; selects the GPU and names the JSON log file.
        args: Run configuration. ``args.output_log_path`` is overwritten here.
        shared_model: Model whose parameters are optimized.
        use_vision: If falsy, image inputs are replaced with zeros.
        use_language: If falsy, every question is replaced by
            ``<START> <END>`` followed by ``<NULL>`` padding.
    """
    gpu_idx = args.gpus.index(args.gpus[rank % len(args.gpus)])
    torch.cuda.set_device(gpu_idx)
    print("train gpu:" + str(gpu_idx) + " assigned")

    if args.input_type == 'ques':
        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)
    elif args.input_type == 'ques,image':
        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    loss_fn = torch.nn.CrossEntropyLoss().cuda()

    # Only parameters that require gradients are handed to Adam.
    optimizer = torch.optim.Adam(
        filter(lambda param: param.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    loader_config = dict(
        questions_h5=args.train_h5,
        data_json=args.data_json,
        vocab=args.vocab_json,
        batch_size=args.batch_size,
        input_type=args.input_type,
        num_frames=args.num_frames,
        split='train',
        max_threads_per_gpu=args.max_threads_per_gpu,
        gpu_id=args.gpus[rank % len(args.gpus)],
        to_cache=args.to_cache,
    )

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    metrics = VqaMetric(
        info={'split': 'train', 'thread': rank},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)
    train_loader = EqaDataLoader(**loader_config)
    if args.input_type == 'ques,image':
        train_loader.dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(train_loader.dataset))

    def _refresh_local_model():
        # Pull the latest shared weights and put the copy on the GPU.
        model.load_state_dict(shared_model.state_dict())
        model.train()
        model.cuda()

    def _without_language(question_batch):
        # Replace each question with a start and end token back to back,
        # padded with the null token.
        blanked = torch.zeros_like(question_batch)
        blanked.fill_(model_kwargs['vocab']['questionTokenToIdx']['<NULL>'])
        blanked[:, 0] = model_kwargs['vocab']['questionTokenToIdx']['<START>']
        blanked[:, 1] = model_kwargs['vocab']['questionTokenToIdx']['<END>']
        return blanked

    def _optimize(loss, scores, answers, step_no):
        # Record metrics, back-propagate and update the shared parameters.
        optimizer.zero_grad()

        accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
        metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

        loss.backward()

        ensure_shared_grads(model.cpu(), shared_model)
        optimizer.step()

        if step_no % args.print_every == 0:
            print(metrics.get_stat_string())
            if args.to_log == 1:
                metrics.dump_log()

    step_no = 0
    epoch = 0

    while epoch < int(args.max_epochs):
        print("train gpu:" + str(gpu_idx) + " running epoch " + str(epoch))

        if args.input_type == 'ques':

            for batch in train_loader:
                step_no += 1
                _refresh_local_model()

                idx, questions, answers = batch

                if not use_language:
                    questions = _without_language(questions)

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = loss_fn(scores, answers_var)

                _optimize(loss, scores, answers, step_no)

        elif args.input_type == 'ques,image':

            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            # One pass over the loader per set of loaded environments; stop
            # once everything was loaded up front or nothing is left to load.
            while True:

                for batch in train_loader:
                    step_no += 1
                    _refresh_local_model()

                    idx, questions, answers, images, _, _, _ = batch

                    if not use_language:
                        questions = _without_language(questions)
                    # If not using vision, replace all image feature data
                    # with zeros.
                    if not use_vision:
                        images = torch.zeros_like(images)

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = loss_fn(scores, answers_var)

                    _optimize(loss, scores, answers, step_no)

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        break
                else:
                    break

        # Set shared epoch when it finishes on the training side
        with open(args.identifier + '.shared_epoch.tmp', 'w') as epoch_file:
            epoch_file.write(str(epoch))

        epoch += 1
Exemple #9
0
def train(rank, args, shared_model):
    """Train a VQA model, with a second peer network in ``'ques,image'`` mode.

    In ``'ques'`` mode a single model is trained with cross-entropy loss. In
    ``'ques,image'`` mode answer labels are randomly perturbed with
    probability ``args.noise_rate`` and two attention models are trained
    together through ``loss_coteaching``, which receives both score tensors,
    the labels and the current entry of the forget-rate schedule.

    The first model follows the shared-model protocol (weights reloaded from
    ``shared_model`` every step, gradients forwarded via
    ``ensure_shared_grads``); the second model is local to this worker and
    has its own optimizer.

    Args:
        rank: Worker index; selects the GPU and names the JSON log file.
        args: Run configuration. ``args.output_log_path`` is overwritten here.
        shared_model: Model whose parameters are optimized.
    """
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.forget_rate is None:
        forget_rate = args.noise_rate
    else:
        forget_rate = args.forget_rate

    # Per-epoch forget rate: constant at forget_rate, except that the first
    # num_gradual entries ramp linearly from 0 to forget_rate ** exponent.
    # max_epochs is cast to int here as it is for the loop bound below.
    rate_schedule = np.ones(int(args.max_epochs)) * forget_rate
    rate_schedule[:args.num_gradual] = np.linspace(0,
                                                   forget_rate**args.exponent,
                                                   args.num_gradual)

    # NOTE(review): any other input_type leaves `model` undefined.
    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

        # The question-only branch still calls lossFn; its definition had
        # been commented out, which made that branch fail with a NameError.
        lossFn = torch.nn.CrossEntropyLoss().cuda()

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model1 = VqaLstmCnnAttentionModel(**model_kwargs)

        # The peer network only exists in this mode, so its optimizer is
        # built here rather than unconditionally.
        optim1 = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                         model1.parameters()),
                                  lr=args.learning_rate)

    # The main optimizer updates the shared parameters, not the local copy.
    optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    shared_model.parameters()),
                             lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    metrics = VqaMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    if args.input_type == 'ques,image':
        train_loader.dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch = 0, 0

    while epoch < int(args.max_epochs):

        if args.input_type == 'ques':

            for batch in train_loader:

                t += 1

                # Start every step from the latest shared weights.
                model.load_state_dict(shared_model.state_dict())
                model.train()
                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # zero grad
                optim.zero_grad()

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update(
                    [loss.data.item(), accuracy, ranks, 1.0 / ranks])

                # backprop and update
                loss.backward()

                ensure_shared_grads(model.cpu(), shared_model)
                optim.step()

                if t % args.print_every == 0:
                    print(metrics.get_stat_string())
                    if args.log == True:
                        metrics.dump_log()

        elif args.input_type == 'ques,image':

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while not done:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    # The CNN sub-module is kept in eval mode during training.
                    model.cnn.eval()
                    model.cuda()

                    # The peer network is not synchronized with any shared
                    # copy; it keeps its own weights across steps.
                    # model1.load_state_dict(shared_model1.state_dict())
                    model1.train()
                    model1.cnn.eval()
                    model1.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    # Inject label noise in place: with probability
                    # noise_rate shift the answer by a random offset in
                    # [-10, 10], re-drawing until the result lies in [0, 70].
                    # An offset of 0 is possible and leaves the label as is.
                    # NOTE(review): 70 is presumably the largest valid answer
                    # index -- confirm against the answer vocabulary.
                    for i in range(len(answers)):
                        if random.random() < args.noise_rate:
                            tempt = random.randint(-10, 10)
                            answers[i] = answers[i] + tempt
                            while (answers[i] < 0 or answers[i] > 70):
                                answers[i] = answers[i] - tempt
                                tempt = random.randint(-10, 10)
                                answers[i] = answers[i] + tempt
                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    scores1, att_probs1 = model1(images_var, questions_var)
                    loss, loss1 = loss_coteaching(scores, scores1, answers_var,
                                                  rate_schedule[epoch])

                    # zero grad
                    optim.zero_grad()
                    optim1.zero_grad()

                    # update metrics (first network only, against the
                    # already-perturbed labels)
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update(
                        [loss.data.item(), accuracy, ranks, 1.0 / ranks])

                    # backprop and update
                    loss.backward()
                    loss1.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    # ensure_shared_grads(model1.cpu(), shared_model1)
                    optim.step()
                    optim1.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        if args.log == True:
                            metrics.dump_log()

                if all_envs_loaded == False:
                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1
Exemple #10
0
def eval(rank, args, shared_nav_model, shared_ans_model):
    """Evaluate the shared navigation model together with the answering model.

    For each epoch (up to ``args.max_epochs``) the weights of
    ``shared_nav_model`` and ``shared_ans_model`` are copied into local
    models, every batch of the ``args.eval_split`` loader is rolled out in
    the simulator from three spawn settings (10, 30, 50), and navigation and
    VQA metrics are accumulated under names suffixed with that setting.
    When the ``accuracy_50`` metric improves, the best value is remembered
    and, if ``epoch % args.eval_every == 0`` and ``args.to_log == 1``, the
    logs are dumped and the navigation model state is saved under
    ``args.checkpoint_dir``.

    Args:
        rank: worker index; used to choose an entry of ``args.gpus`` and to
            name the per-worker log files.
        args: options object. Attributes read here: ``gpus``, ``model_type``,
            ``vocab_json``, ``eval_split`` (and the matching
            ``<eval_split>_h5`` attribute), ``data_json``,
            ``target_obj_conn_map_dir``, ``map_resolution``,
            ``max_threads_per_gpu``, ``log_dir``, ``max_epochs``,
            ``max_episode_length``, ``eval_every``, ``to_log`` and
            ``checkpoint_dir``. ``output_nav_log_path`` and
            ``output_ans_log_path`` are written onto it.
        shared_nav_model: model whose ``state_dict`` is loaded into the local
            navigation model (at the start of each epoch and again for every
            batch).
        shared_ans_model: model whose ``state_dict`` is loaded into the local
            answering model at the start of each epoch.

    Note:
        Calls ``exit()`` when ``args.model_type`` is not ``'pacman'``.
    """

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        nav_model = NavPlannerControllerModel(**model_kwargs)

    else:

        # Only the 'pacman' navigator is handled; anything else ends the process.
        exit()

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

    # Evaluation runs one episode at a time and does not cache.
    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': False
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    # Per-worker log files, one for navigation metrics and one for answering metrics.
    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_eval_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0.0

    while epoch < int(args.max_epochs):

        # NOTE(review): start_time is not read again anywhere in this function.
        start_time = time.time()
        # [question index, spawn setting] pairs that were skipped this epoch.
        invalids = []

        # Pick up the latest shared weights before evaluating.
        nav_model.load_state_dict(shared_nav_model.state_dict())
        nav_model.eval()

        ans_model.load_state_dict(shared_ans_model.state_dict())
        ans_model.eval()
        ans_model.cuda()

        # Fresh metric trackers every epoch; each metric is tracked separately
        # for the three spawn settings (suffixes _10, _30, _50).
        nav_metrics = NavMetric(
            info={'split': args.eval_split,
                  'thread': rank},
            metric_names=[
                'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                'ep_len_30', 'ep_len_50'
            ],
            log_json=args.output_nav_log_path)

        vqa_metrics = VqaMetric(
            info={'split': args.eval_split,
                  'thread': rank},
            metric_names=[
                'accuracy_10', 'accuracy_30', 'accuracy_50', 'mean_rank_10',
                'mean_rank_30', 'mean_rank_50', 'mean_reciprocal_rank_10',
                'mean_reciprocal_rank_30', 'mean_reciprocal_rank_50'
            ],
            log_json=args.output_ans_log_path)

        if 'pacman' in args.model_type:

            done = False

            # Keep iterating the loader, loading further environments after each
            # pass, until the dataset reports no environments left to load.
            while done == False:

                for batch in tqdm(eval_loader):

                    # Re-sync with the shared navigator for every batch.
                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch
                    # Metric values computed for this batch, keyed by '<metric>_<i>'.
                    metrics_slug = {}

                    h3d = eval_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    # (i is forwarded to get_hierarchical_features_till_spawn;
                    # presumably the number of steps back from the end of the
                    # ground-truth path - TODO confirm against the dataset class)
                    for i in [10, 30, 50]:

                        t += 1

                        # Ground-truth action sequence is too short for this setting.
                        if i > action_length[0]:
                            invalids.append([idx[0], i])
                            continue

                        question_var = Variable(question.cuda())

                        controller_step = False
                        planner_hidden = nav_model.planner_nav_rnn.init_hidden(
                            1)

                        # forward through planner till spawn
                        planner_actions_in, planner_img_feats, controller_step, controller_action_in, controller_img_feat, init_pos = eval_loader.dataset.get_hierarchical_features_till_spawn(
                            actions[0, :action_length[0] + 1].numpy(), i)

                        planner_actions_in_var = Variable(
                            planner_actions_in.cuda())
                        planner_img_feats_var = Variable(
                            planner_img_feats.cuda())

                        # Feed the returned features/actions through the planner one
                        # step at a time so planner_hidden holds the state at spawn.
                        for step in range(planner_actions_in.size(0)):

                            planner_scores, planner_hidden = nav_model.planner_step(
                                question_var, planner_img_feats_var[step].view(
                                    1, 1,
                                    3200), planner_actions_in_var[step].view(
                                        1, 1), planner_hidden)

                        if controller_step == True:

                            # Spawn point falls on a controller step: ask the
                            # controller whether to keep repeating the given action.
                            controller_img_feat_var = Variable(
                                controller_img_feat.cuda())
                            controller_action_in_var = Variable(
                                torch.LongTensor(1, 1).fill_(
                                    int(controller_action_in)).cuda())

                            controller_scores = nav_model.controller_step(
                                controller_img_feat_var.view(1, 1, 3200),
                                controller_action_in_var.view(1, 1),
                                planner_hidden[0])

                            prob = F.softmax(controller_scores, dim=1)
                            controller_action = int(
                                prob.max(1)[1].data.cpu().numpy()[0])

                            if controller_action == 1:
                                controller_step = True
                            else:
                                controller_step = False

                            action = int(controller_action_in)
                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                        else:

                            # Otherwise take the planner's highest-scoring action.
                            prob = F.softmax(planner_scores, dim=1)
                            action = int(prob.max(1)[1].data.cpu().numpy()[0])

                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                        # Place the agent at the spawn pose (elements 0, 2 and 3 of
                        # init_pos are passed as x, y and yaw).
                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        episode_length = 0
                        episode_done = True
                        controller_action_counter = 0

                        dists_to_target, pos_queue, pred_actions = [
                            init_dist_to_target
                        ], [init_pos], []
                        planner_actions, controller_actions = [], []

                        # Action 3 is what the 'stop_' metric below counts as a stop;
                        # when it is predicted right away no step is taken at all.
                        if action != 3:

                            # take the first step
                            img, _, _ = h3d.step(action)
                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = eval_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224,
                                                  224).cuda())).view(
                                                      1, 1, 3200)

                            for step in range(args.max_episode_length):

                                episode_length += 1

                                # The planner only picks a new action when the
                                # controller has handed control back; otherwise the
                                # previous action is repeated.
                                if controller_step == False:
                                    planner_scores, planner_hidden = nav_model.planner_step(
                                        question_var, img_feat_var,
                                        Variable(action_in), planner_hidden)

                                    prob = F.softmax(planner_scores, dim=1)
                                    action = int(
                                        prob.max(1)[1].data.cpu().numpy()[0])
                                    planner_actions.append(action)

                                pred_actions.append(action)
                                img, _, episode_done = h3d.step(action)

                                episode_done = episode_done or episode_length >= args.max_episode_length

                                img = torch.from_numpy(img.transpose(
                                    2, 0, 1)).float() / 255.0
                                img_feat_var = eval_loader.dataset.cnn(
                                    Variable(img.view(1, 3, 224, 224)
                                             .cuda())).view(1, 1, 3200)

                                # Record distance and pose after every step.
                                dists_to_target.append(
                                    h3d.get_dist_to_target(h3d.env.cam.pos))
                                pos_queue.append([
                                    h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                    h3d.env.cam.pos.z, h3d.env.cam.yaw
                                ])

                                if episode_done == True:
                                    break

                                # query controller to continue or not
                                controller_action_in = Variable(
                                    torch.LongTensor(1,
                                                     1).fill_(action).cuda())
                                controller_scores = nav_model.controller_step(
                                    img_feat_var, controller_action_in,
                                    planner_hidden[0])

                                prob = F.softmax(controller_scores, dim=1)
                                controller_action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

                                # The controller may keep control for at most four
                                # consecutive decisions; after that the counter is
                                # reset and the planner is queried again.
                                if controller_action == 1 and controller_action_counter < 4:
                                    controller_action_counter += 1
                                    controller_step = True
                                else:
                                    controller_action_counter = 0
                                    controller_step = False
                                    controller_action = 0

                                controller_actions.append(controller_action)

                                action_in = torch.LongTensor(
                                    1, 1).fill_(action + 1).cuda()

                        # run answerer here
                        # The answerer is given five frames; when fewer than five
                        # positions were visited, entries from the dataset's
                        # episode_pos_queue are prepended to make up the difference.
                        if len(pos_queue) < 5:
                            pos_queue = eval_loader.dataset.episode_pos_queue[len(
                                pos_queue) - 5:] + pos_queue
                        images = eval_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)
                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)

                        pred_answer = scores.max(1)[1].data[0]

                        # Print the question (skipping index-0 tokens), the
                        # ground-truth answer and the predicted answer.
                        print('[Q_GT]', ' '.join([
                            eval_loader.dataset.vocab['questionIdxToToken'][x]
                            for x in question[0] if x != 0
                        ]))
                        print('[A_GT]', eval_loader.dataset.vocab[
                            'answerIdxToToken'][answer[0]])
                        print('[A_PRED]', eval_loader.dataset.vocab[
                            'answerIdxToToken'][pred_answer])

                        # compute stats
                        metrics_slug['accuracy_' + str(i)] = ans_acc[0]
                        metrics_slug['mean_rank_' + str(i)] = ans_rank[0]
                        metrics_slug['mean_reciprocal_rank_'
                                     + str(i)] = 1.0 / ans_rank[0]

                        # d_0: distance at spawn, d_T: distance at the end,
                        # d_D: their difference, d_min: smallest distance recorded.
                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug['d_D_' + str(
                            i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' + str(i)] = np.array(
                            dists_to_target).min()
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        if action == 3:
                            metrics_slug['stop_' + str(i)] = 1
                        else:
                            metrics_slug['stop_' + str(i)] = 0
                        # r_T: final position is inside the target room;
                        # r_e: any recorded position is inside the target room.
                        inside_room = []
                        for p in pos_queue:
                            inside_room.append(
                                h3d.is_inside_room(
                                    p, eval_loader.dataset.target_room))
                        if inside_room[-1] == True:
                            metrics_slug['r_T_' + str(i)] = 1
                        else:
                            metrics_slug['r_T_' + str(i)] = 0
                        if any([x == True for x in inside_room]) == True:
                            metrics_slug['r_e_' + str(i)] = 1
                        else:
                            metrics_slug['r_e_' + str(i)] = 0

                    # navigation metrics
                    # Metrics that were not computed for this batch (skipped spawn
                    # settings) are filled with element 0 of the tracker's stored
                    # entry - presumably its current running value; verify against
                    # NavMetric.
                    metrics_list = []
                    for i in nav_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(nav_metrics.metrics[
                                nav_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    nav_metrics.update(metrics_list)

                    # vqa metrics
                    # Same fill-in rule as for the navigation metrics above.
                    metrics_list = []
                    for i in vqa_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(vqa_metrics.metrics[
                                vqa_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    vqa_metrics.update(metrics_list)

                # NOTE(review): the bare except silently hides any failure while
                # formatting or printing the statistics.
                try:
                    print(nav_metrics.get_stat_string(mode=0))
                    print(vqa_metrics.get_stat_string(mode=0))
                except:
                    pass

                print('epoch', epoch)
                print('invalids', len(invalids))

                # Load more environments; stop once pruned_env_set is empty.
                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy
        # (index 2 of the VQA metric names is 'accuracy_50')
        if vqa_metrics.metrics[2][0] > best_eval_acc:  # ans_acc_50
            best_eval_acc = vqa_metrics.metrics[2][0]
            # The best value is always remembered, but logs and the checkpoint
            # are only written on eval_every epochs with logging enabled.
            if epoch % args.eval_every == 0 and args.to_log == 1:
                vqa_metrics.dump_log()
                nav_metrics.dump_log()

                model_state = get_state(nav_model)

                # Save a copy of the options without entries whose name starts
                # with an underscore.
                aad = dict(args.__dict__)
                ad = {}
                for i in aad:
                    if i[0] != '_':
                        ad[i] = aad[i]

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_ans_50_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_ans_acc_50:%.04f]' % best_eval_acc)

        # Start the next epoch from the first environments again, in order.
        eval_loader.dataset._load_envs(start_idx=0, in_order=True)
Exemple #11
0
def eval(rank, args, shared_model):
    """Evaluate the shared VQA model and checkpoint it when accuracy improves.

    For each epoch (up to ``args.max_epochs``) the weights of
    ``shared_model`` are copied into a local model, the ``args.eval_split``
    loader is run through it, and loss / accuracy / rank metrics are
    accumulated. When the accuracy metric beats the best value seen so far
    the best value is updated and, if ``epoch % args.eval_every == 0`` and
    ``args.to_log == 1``, the log is dumped and a checkpoint is written to
    ``args.checkpoint_dir``.

    Args:
        rank: worker index; used to choose an entry of ``args.gpus`` and to
            name the log file.
        args: options object. Attributes read here: ``gpus``, ``input_type``
            (``'ques'`` or ``'ques,image'``), ``vocab_json``, ``eval_split``
            (and the matching ``<eval_split>_h5`` attribute), ``data_json``,
            ``num_frames``, ``max_threads_per_gpu``, ``to_cache``,
            ``log_dir``, ``max_epochs``, ``eval_every``, ``to_log`` and
            ``checkpoint_dir``. ``output_log_path`` is written onto it.
        shared_model: model whose ``state_dict`` is loaded into the local
            model at the start of every epoch.

    Raises:
        ValueError: if ``args.input_type`` is not one of the two supported
            values (previously this surfaced later as an unbound ``model``).
    """

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':
        model = VqaLstmModel(vocab=load_vocab(args.vocab_json))
    elif args.input_type == 'ques,image':
        model = VqaLstmCnnAttentionModel(vocab=load_vocab(args.vocab_json))
    else:
        raise ValueError('Unsupported input_type: %r' % (args.input_type,))

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    epoch, best_eval_acc = 0, 0

    while epoch < int(args.max_epochs):

        # Pick up the latest shared weights before evaluating.
        model.load_state_dict(shared_model.state_dict())
        model.eval()

        # Fresh metric tracker every epoch.
        metrics = VqaMetric(
            info={'split': args.eval_split},
            metric_names=[
                'loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'
            ],
            log_json=args.output_log_path)

        if args.input_type == 'ques':
            for batch in eval_loader:
                model.cuda()

                _, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                # .item() reads the scalar loss; indexing a 0-dim tensor with
                # [0] is rejected by current PyTorch.
                metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

            print(metrics.get_stat_string(mode=0))

        elif args.input_type == 'ques,image':
            done = False
            all_envs_loaded = eval_loader.dataset._check_if_all_envs_loaded()

            # Keep iterating the loader, loading further environments after
            # each pass, until the dataset reports none left to load.
            while not done:
                for batch in eval_loader:
                    model.cuda()

                    _, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, _ = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update(
                        [loss.item(), accuracy, ranks, 1.0 / ranks])

                print(metrics.get_stat_string(mode=0))

                if all_envs_loaded == False:
                    eval_loader.dataset._load_envs()
                    if len(eval_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy (index 1 of the metric names)
        if metrics.metrics[1][0] > best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]
            if epoch % args.eval_every == 0 and args.to_log == 1:
                metrics.dump_log()

                model_state = get_state(model)

                # Save a copy of the options without entries whose name starts
                # with an underscore. The previous code read
                # ``checkpoint['args']`` here, but ``checkpoint`` is a local
                # that is only assigned on the next statement, so that read
                # raised UnboundLocalError.
                ad = {
                    key: value
                    for key, value in dict(args.__dict__).items()
                    if not key.startswith('_')
                }

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)
Exemple #12
0
def train(rank, args, shared_model):
    """Train the shared VQA model from one worker.

    For every batch the worker copies ``shared_model``'s weights into a local
    model, runs the forward and backward pass on the GPU, hands the gradients
    to ``shared_model`` through ``ensure_shared_grads`` and steps an Adam
    optimizer built over ``shared_model``'s trainable parameters. Training
    metrics are printed (and dumped when ``args.to_log == 1``) every
    ``args.print_every`` batches.

    Args:
        rank: worker index; used to choose an entry of ``args.gpus`` and to
            name the log file.
        args: options object. Attributes read here: ``gpus``, ``input_type``
            (``'ques'`` or ``'ques,image'``), ``vocab_json``, ``train_h5``,
            ``data_json``, ``batch_size``, ``num_frames``,
            ``max_threads_per_gpu``, ``to_cache``, ``learning_rate``,
            ``log_dir``, ``max_epochs``, ``print_every`` and ``to_log``.
            ``output_log_path`` is written onto it.
        shared_model: model that supplies the weights for every step and
            whose parameters the optimizer updates.
    """

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer owns the shared parameters, not the local copy.
    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    metrics = VqaMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    if args.input_type == 'ques,image':
        train_loader.dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(train_loader.dataset))

    def _sync_local_model():
        """Copy the shared weights into the local model and put it on the GPU in train mode."""
        model.load_state_dict(shared_model.state_dict())
        model.train()
        model.cuda()

    def _optimize(scores, answers, step):
        """Record metrics for one batch, backpropagate and update the shared model."""
        loss = lossFn(scores, Variable(answers.cuda()))

        # zero grad
        optim.zero_grad()

        # update metrics
        accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
        # .item() reads the scalar loss; indexing a 0-dim tensor with [0] is
        # rejected by current PyTorch.
        metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

        # backprop and update
        loss.backward()

        # Gradients are handed over from the CPU copy of the local model.
        ensure_shared_grads(model.cpu(), shared_model)
        optim.step()

        if step % args.print_every == 0:
            print(metrics.get_stat_string())
            if args.to_log == 1:
                metrics.dump_log()

    t, epoch = 0, 0

    while epoch < int(args.max_epochs):

        if args.input_type == 'ques':

            for batch in train_loader:

                t += 1
                _sync_local_model()

                _, questions, answers = batch

                scores = model(Variable(questions.cuda()))
                _optimize(scores, answers, t)

        elif args.input_type == 'ques,image':

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            # Keep iterating the loader, loading further environments after
            # each pass, until the dataset reports none left to load.
            while not done:

                for batch in train_loader:

                    t += 1
                    _sync_local_model()

                    _, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    images_var = Variable(images.cuda())

                    scores, _ = model(images_var, questions_var)
                    _optimize(scores, answers, t)

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1
Exemple #13
0
def train(rank, args, shared_nav_model, shared_ans_model):

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        nav_model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

    optim = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                   shared_nav_model.parameters()),
                            lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,  # FOR REINFORCE!!!
        'input_type': args.model_type,
        'num_frames': 5,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.to_cache
    }

    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_train_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_train_' + str(rank) + '.json')

    nav_model.load_state_dict(shared_nav_model.state_dict())
    nav_model.cuda()

    ans_model.load_state_dict(shared_ans_model.state_dict())
    ans_model.eval()
    ans_model.cuda()

    # Saty: add coverage metric here (Should be inculcated into the reward structure)
    nav_metrics = NavMetric(info={
        'split': 'train',
        'thread': rank
    },
                            metric_names=[
                                'planner_loss', 'controller_loss', 'reward',
                                'episode_length'
                            ],
                            log_json=args.output_nav_log_path)

    vqa_metrics = VqaMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_ans_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch = 0, 0
    p_losses, c_losses, reward_list, episode_length_list = [], [], [], []

    nav_metrics.update([10.0, 10.0, 0, 100])

    mult = 0.1
    best_eval_acc = 0.0
    while epoch < int(args.max_epochs):

        print('###############################')
        print('[train] Epoch is:', epoch)
        print('###############################')

        if 'pacman' in args.model_type:

            planner_lossFn = MaskedNLLCriterion().cuda()
            controller_lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            #himi changes
            envs = 1

            while done == False:

                for batch in train_loader:

                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch
                    metrics_slug = {}

                    h3d = train_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    # for i in [10, 30, 50]:

                    t += 1

                    question_var = Variable(question.cuda())

                    controller_step = False
                    planner_hidden = nav_model.planner_nav_rnn.init_hidden(1)

                    # forward through planner till spawn
                    (planner_actions_in, planner_img_feats, controller_step, controller_action_in, controller_img_feat, init_pos, _ ) = \
                    train_loader.dataset.get_hierarchical_features_till_spawn(
                        actions[0, :action_length[0] + 1].numpy(),
                        max(3, int(mult * action_length[0])))

                    planner_actions_in_var = Variable(
                        planner_actions_in.cuda())
                    planner_img_feats_var = Variable(planner_img_feats.cuda())

                    #  Sati: Run Planner till target_pos_idx -> This is from get_hierarchical_features...
                    # need T-1 hidden state for planner! -> GT img_feats, Actions and question!
                    for step in range(planner_actions_in.size(0)):

                        planner_scores, planner_hidden = \
                        nav_model.planner_step(
                            question_var,
                            planner_img_feats_var[step].view(1, 1, 3200),
                            planner_actions_in_var[step].view(1, 1),
                            planner_hidden)

                    if controller_step == True:

                        controller_img_feat_var = Variable(
                            controller_img_feat.cuda())
                        controller_action_in_var = Variable(
                            torch.LongTensor(1, 1).fill_(
                                int(controller_action_in)).cuda())

                        controller_scores = nav_model.controller_step(
                            controller_img_feat_var.view(1, 1, 3200),
                            controller_action_in_var.view(1, 1),
                            planner_hidden[0])

                        prob = F.softmax(controller_scores, dim=1)
                        controller_action = int(
                            prob.max(1)[1].data.cpu().numpy()[0])

                        if controller_action == 1:
                            controller_step = True
                        else:
                            controller_step = False

                        action = int(controller_action_in)
                        action_in = torch.LongTensor(1, 1).fill_(action +
                                                                 1).cuda()

                    else:

                        prob = F.softmax(planner_scores, dim=1)
                        action = int(prob.max(1)[1].data.cpu().numpy()[0])

                        action_in = torch.LongTensor(1, 1).fill_(action +
                                                                 1).cuda()

                    h3d.env.reset(x=init_pos[0],
                                  y=init_pos[2],
                                  yaw=init_pos[3])

                    init_dist_to_target = h3d.get_dist_to_target(
                        h3d.env.cam.pos)
                    if init_dist_to_target < 0:  # unreachable
                        invalids.append([idx[0], i])
                        continue

                    episode_length = 0
                    episode_done = True
                    controller_action_counter = 0

                    dists_to_target, pos_queue = [init_dist_to_target
                                                  ], [init_pos]

                    rewards, planner_actions, planner_log_probs, controller_actions, controller_log_probs = [], [], [], [], []

                    if action != 3:

                        # take the first step -> Include coverage in reward!
                        img, rwd, episode_done = h3d.step(action,
                                                          step_reward=True)
                        img = torch.from_numpy(img.transpose(
                            2, 0, 1)).float() / 255.0
                        img_feat_var = train_loader.dataset.cnn(
                            Variable(img.view(1, 3, 224,
                                              224).cuda())).view(1, 1, 3200)

                        for step in range(args.max_episode_length):

                            episode_length += 1

                            if controller_step == False:
                                planner_scores, planner_hidden = nav_model.planner_step(
                                    question_var, img_feat_var,
                                    Variable(action_in), planner_hidden)

                                planner_prob = F.softmax(planner_scores, dim=1)
                                planner_log_prob = F.log_softmax(
                                    planner_scores, dim=1)

                                action = planner_prob.multinomial(
                                    num_samples=1).data
                                planner_log_prob = planner_log_prob.gather(
                                    1, Variable(action))

                                planner_log_probs.append(
                                    planner_log_prob.cpu())

                                action = int(action.cpu().numpy()[0, 0])
                                planner_actions.append(action)

                            img, rwd, episode_done = h3d.step(action,
                                                              step_reward=True)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            rewards.append(rwd)

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = train_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224,
                                                  224).cuda())).view(
                                                      1, 1, 3200)

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done == True:
                                break

                            # query controller to continue or not
                            controller_action_in = Variable(
                                torch.LongTensor(1, 1).fill_(action).cuda())
                            controller_scores = nav_model.controller_step(
                                img_feat_var, controller_action_in,
                                planner_hidden[0])

                            controller_prob = F.softmax(controller_scores,
                                                        dim=1)
                            controller_log_prob = F.log_softmax(
                                controller_scores, dim=1)

                            controller_action = controller_prob.multinomial(
                                num_samples=1).data

                            if int(controller_action[0]
                                   ) == 1 and controller_action_counter < 4:
                                controller_action_counter += 1
                                controller_step = True
                            else:
                                controller_action_counter = 0
                                controller_step = False
                                controller_action.fill_(0)

                            controller_log_prob = controller_log_prob.gather(
                                1, Variable(controller_action))
                            controller_log_probs.append(
                                controller_log_prob.cpu())

                            controller_action = int(
                                controller_action.cpu().numpy()[0, 0])
                            controller_actions.append(controller_action)
                            action_in = torch.LongTensor(1, 1).fill_(action +
                                                                     1).cuda()

                    # run answerer here
                    ans_acc = [0]
                    if action == 3:
                        if len(pos_queue) < 5:
                            pos_queue = train_loader.dataset.episode_pos_queue[
                                len(pos_queue) - 5:] + pos_queue
                        images = train_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)
                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)
                        vqa_metrics.update([ans_acc, ans_rank, 1.0 / ans_rank])

                    rewards.append(h3d.success_reward * ans_acc[0])

                    R = torch.zeros(1, 1)

                    planner_loss = 0
                    controller_loss = 0

                    planner_rev_idx = -1
                    for i in reversed(range(len(rewards))):
                        R = 0.99 * R + rewards[i]
                        advantage = R - nav_metrics.metrics[2][1]

                        if i < len(controller_actions):
                            controller_loss = controller_loss - controller_log_probs[
                                i] * Variable(advantage)

                            if controller_actions[
                                    i] == 0 and planner_rev_idx + len(
                                        planner_log_probs) >= 0:
                                planner_loss = planner_loss - planner_log_probs[
                                    planner_rev_idx] * Variable(advantage)
                                planner_rev_idx -= 1

                        elif planner_rev_idx + len(planner_log_probs) >= 0:

                            planner_loss = planner_loss - planner_log_probs[
                                planner_rev_idx] * Variable(advantage)
                            planner_rev_idx -= 1

                    controller_loss /= max(1, len(controller_log_probs))
                    planner_loss /= max(1, len(planner_log_probs))

                    optim.zero_grad()

                    if isinstance(planner_loss, float) == False and isinstance(
                            controller_loss, float) == False:
                        p_losses.append(planner_loss.data[0, 0])
                        c_losses.append(controller_loss.data[0, 0])
                        reward_list.append(np.sum(rewards))
                        episode_length_list.append(episode_length)

                        (planner_loss + controller_loss).backward()

                        ensure_shared_grads(nav_model.cpu(), shared_nav_model)
                        optim.step()

                    if len(reward_list) > 50:

                        nav_metrics.update([
                            p_losses, c_losses, reward_list,
                            episode_length_list
                        ])
                        envs += 1
                        print('[train train_eqa.py] Envs is: ', envs)

                        print(nav_metrics.get_stat_string())

                        if args.to_log == 1 and envs % 10 == 0:
                            vqa_metrics.dump_log()
                            nav_metrics.dump_log()

                            model_state = get_state(nav_model)

                            aad = dict(args.__dict__)
                            ad = {}
                            for i in aad:
                                if i[0] != '_':
                                    ad[i] = aad[i]

                            checkpoint = {
                                'args': ad,
                                'state': model_state,
                                'epoch': epoch
                            }

                            checkpoint_path = '%s/epoch_%d_ans_10_envs_%d.pt' % (
                                args.checkpoint_dir, epoch, envs)
                            print('Saving checkpoint to %s' % checkpoint_path)
                            torch.save(checkpoint, checkpoint_path)
                        ##################################
                        if args.to_log == 1:
                            nav_metrics.dump_log()

                        if nav_metrics.metrics[2][1] > 0.35:
                            mult = min(mult + 0.1, 1.0)

                        p_losses, c_losses, reward_list, episode_length_list = [], [], [], []

                #Himi changes, checking every 100
                print('Out of that loop')
                # if args.to_log == 1 and envs % 2 == 0 :
                #     vqa_metrics.dump_log()
                #     nav_metrics.dump_log()

                #     model_state = get_state(nav_model)

                #     aad = dict(args.__dict__)
                #     ad = {}
                #     for i in aad:
                #         if i[0] != '_':
                #             ad[i] = aad[i]

                #     checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                #     checkpoint_path = '%s/epoch_%d_ans_10_envs_%d.pt' % (
                #         args.checkpoint_dir, epoch, envs)
                #     print('Saving checkpoint to %s' % checkpoint_path)
                #     torch.save(checkpoint, checkpoint_path)
                # ##################################

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.to_cache == False:
                            train_loader.dataset._load_envs(start_idx=0,
                                                            in_order=True)
                else:
                    done = True

        epoch += 1
        best_eval_acc = eval(0, args, nav_model, ans_model, best_eval_acc,
                             epoch)