예제 #1
0
def main(args):
    """Evaluate a saved Navigator checkpoint on one split and report the results.

    Loads ``<args.checkpoint_dir>/<args.id>.pth``, rebuilds the model from the
    options stored in the checkpoint, runs ``evaluate`` on a
    ``NavImitationDataset`` and finally writes the predictions to
    ``args.result_json`` and appends a text summary to ``args.report_txt``.

    Args:
        args: namespace providing ``checkpoint_dir``, ``id``, ``data_json``,
            ``data_h5``, ``path_feats_dir``, ``path_images_dir``, ``split``,
            ``result_json`` and ``report_txt``.
    """
    # set up model from the checkpoint (weights + the options it was built with)
    checkpoint_path = osp.join(args.checkpoint_dir, '%s.pth' % args.id)
    checkpoint = torch.load(checkpoint_path)
    opt = checkpoint['opt']
    model = Navigator(opt)
    model.load_state_dict(checkpoint['model_state'])
    model.cuda()
    print('model set up.')

    # set up criterions
    nll_crit = SeqModelCriterion().cuda()
    mse_crit = MaskedMSELoss().cuda()

    # set up loader
    loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': args.split,
        'max_seq_length': 100,
        'requires_imgs': False,
        'nav_types': ['room'],
        'question_types': ['all'],
    }
    dataset = NavImitationDataset(**loader_kwargs)

    # evaluate
    (predictions, overall_nll, overall_teacher_forcing_acc, overall_mse,
     Nav_nll, Nav_teacher_forcing_acc, Nav_cnt) = evaluate(
         dataset, model, nll_crit, mse_crit, opt)

    # summarize: header line lists the model options that were switched on
    results_str = 'id[%s] ' % args.id
    option_tags = (
        ('use_action', '[use action]'),
        ('use_curriculum', '[use curriculum]'),
        ('use_next', '[use next]'),
        ('use_residual', '[use residual]'),
    )
    for option, tag in option_tags:
        if opt[option]:
            results_str += tag
    results_str += '\nsplit[%s]\n' % args.split

    results_str += '  nll_loss: %.3f\n' % overall_nll
    results_str += '  teacher-forcing acc (%s): %.2f%%,' % (
        len(predictions), overall_teacher_forcing_acc * 100.)
    results_str += ' on %s objects: %.2f%%,' % (
        Nav_cnt['object'], Nav_teacher_forcing_acc['object'] * 100.)
    results_str += ' on %s rooms: %.2f%%\n' % (
        Nav_cnt['room'], Nav_teacher_forcing_acc['room'] * 100.)

    # save predictions (overwrite) and append the summary to the running report;
    # both files are handled by context managers so they are closed even if a
    # write fails
    with open(args.result_json, 'w') as f:
        json.dump(predictions, f)
    with open(args.report_txt, 'a') as f:
        f.write(results_str)
        f.write('\n')
예제 #2
0
def eval(rank, args, shared_nav_model, counter, split):
    """Roll out the navigator in the simulated environments and report metrics.

    Builds a ``NavReinforceDataset`` for ``split``, copies the weights of
    ``shared_nav_model`` into a fresh ``Navigator`` and, for every sample,
    spawns the agent 5, 10 and 15 positions before the end of the recorded
    path (``episode_pos_queue``) and lets the model act greedily for at most
    ``args.max_episode_length`` steps.  Per-sample predictions are dumped via
    ``NavMetric.dump_log`` and a summary is appended to ``args.report_txt``.

    Args:
        rank: worker index; only recorded as ``'thread'`` in the metric info.
        args: namespace with data/house paths and evaluation options.
        shared_nav_model: model whose ``state_dict`` is evaluated.
        counter: not used by this function.
        split: dataset split to evaluate; also written to the report.
    """
    # metric_names
    metric_names = [
        'd_0_5', 'd_0_10', 'd_0_15', 'd_T_5', 'd_T_10', 'd_T_15', 'd_D_5',
        'd_D_10', 'd_D_15', 'd_min_5', 'd_min_10', 'd_min_15', 'h_T_5',
        'h_T_10', 'h_T_15', 'r_T_5', 'r_T_10', 'r_T_15', 'r_e_5', 'r_e_10',
        'r_e_15', 'stop_5', 'stop_10', 'stop_15', 'ep_len_5', 'ep_len_10',
        'ep_len_15'
    ]
    # the agent is spawned this many positions before the end of the path
    init_steps = [5, 10, 15]

    # set up dataset
    cfg = {
        'colorFile':
        osp.join(args.house_meta_dir, 'colormap_fine.csv'),
        'roomTargetFile':
        osp.join(args.house_meta_dir, 'room_target_object_map.csv'),
        'modelCategoryFile':
        osp.join(args.house_meta_dir, 'ModelCategoryMapping.csv'),
        'prefix':
        args.house_data_dir,
    }
    loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': split,
        'max_seq_length': args.max_seq_length,
        'nav_types': ['room'],
        'gpu_id': 0,
        'cfg': cfg,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': 500,
        'pretrained_cnn_path': args.pretrained_cnn_path,
        'requires_imgs': False,
        'question_types': ['all'],
    }
    dataset = NavReinforceDataset(**loader_kwargs)
    eval_loader = DataLoader(dataset, batch_size=1, num_workers=0)
    print('eval_loader set up.')

    # set up model
    opt = vars(args)
    opt['act_to_ix'] = dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    print('eval navigator set up.')

    # (idx, init_step) pairs that could not be evaluated
    invalids = []

    # metrics
    nav_metrics = NavMetric(
        info={
            'split': split,
            'thread': rank
        },
        metric_names=metric_names,
        log_json=args.results_json,
    )

    # sync model (fixed from now on)
    model.load_state_dict(shared_nav_model.state_dict())
    model.eval()
    model.cuda()

    # reset envs
    eval_loader.dataset._load_envs(start_idx=0, in_order=True)

    # run
    done = False  # becomes True once no environments are left to load
    predictions = []
    while not done:
        for batch in tqdm(eval_loader):
            # load target_paths from available_idx (of some envs)
            # batch = {idx, qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats,
            # action_inputs, action_outputs, action_masks, action_length}
            idx = batch['idx'][0].item()
            qid = batch['qid'][0].item()
            phrase_emb = batch['phrase_emb'].cuda()  # (1, 300) float
            raw_ego_feats = batch['ego_feats']  # (1, L, 3200) float
            action_inputs = batch['action_inputs']  # (1, L) int
            action_outputs = batch['action_outputs']  # (1, L) int
            action_length = batch['action_length'][0].item()
            tgt_id = batch['id'][0]
            tgt_type = batch['type'][0]
            pred = {
                'qid': qid,
                'house': batch['house'][0],
                'id': batch['id'][0],
                'type': batch['type'][0],
                'path_name': batch['path_name'][0],
                'path_ix': batch['path_ix'][0].item(),
                'start_ix': batch['start_ix'][0].item(),
                'key_ix': batch['key_ix'][0].item(),
                'action_length': action_length,
                'phrase': batch['phrase'][0],
            }
            metrics_slug = {}

            # evaluate at multiple initializations
            for i in init_steps:

                if action_length - i < 0:
                    invalids.append((idx, i))
                    continue

                h3d = eval_loader.dataset.episode_house
                episode_length = 0
                dists_to_target = []
                pos_queue = []
                actions = []

                # forward through lstm till spawn
                prev_pos_queue = eval_loader.dataset.episode_pos_queue[:-i]
                # prev_len is bound on both branches so the 'gd_actions'
                # slice below never reads a missing or stale value
                prev_len = len(prev_pos_queue)
                if prev_len:
                    ego_feats_pruned = raw_ego_feats[:, :prev_len, :].cuda(
                    )  # (1, l, 3200)
                    action_inputs_pruned = action_inputs[:, :prev_len].cuda(
                    )  # (1, l)
                    # evaluation only: no autograd graph is needed
                    with torch.no_grad():
                        _, _, _, state = model(
                            ego_feats_pruned, phrase_emb,
                            action_inputs_pruned)  # (1, l, rnn_size)
                    action = action_inputs[0, prev_len].view(
                        -1).cuda()  # (1, )
                    init_pos = eval_loader.dataset.episode_pos_queue[-i]
                else:
                    # path has no prefix before the spawn point: start from
                    # a fresh state at the first recorded position
                    state = None
                    action = torch.LongTensor([
                        eval_loader.dataset.act_to_ix['dummy']
                    ]).cuda()  # (1, )
                    init_pos = eval_loader.dataset.episode_pos_queue[0]

                # spawn
                h3d.env.reset(x=init_pos[0],
                              y=init_pos[2],
                              yaw=init_pos[3])
                init_dist_to_target = h3d.get_dist_to_target(
                    h3d.env.cam.pos)
                if init_dist_to_target < 0:  # unreachable
                    invalids.append([idx, i])
                    continue
                dists_to_target += [init_dist_to_target]
                pos_queue += [init_pos]

                # act
                ego_img = h3d.env.render()
                ego_img = (
                    torch.from_numpy(ego_img.transpose(2, 0, 1)).float() /
                    255.).cuda()
                ego_feat = eval_loader.dataset.cnn.extract_feats(
                    ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)
                collision = False
                rot_act = None
                rot_act_cnt = 0
                for step in range(args.max_episode_length):
                    # forward model
                    episode_length += 1
                    with torch.no_grad():
                        logprobs, state = model.forward_step(
                            ego_feat, phrase_emb, action,
                            state)  # (1, 4), (1, rnn_size)

                    # special case 1:
                    # if previous action is "forward" and a collision happened and
                    # this action is still "forward", suppress it.
                    if action.item() == 0 and collision and torch.exp(
                            logprobs[0]).argmax().item() == 0:
                        logprobs[0][0] = -1e5

                    # special case 2:
                    # if the same rotation was chosen more than 5 times in a row,
                    # suppress it
                    if torch.exp(logprobs[0]).argmax().item(
                    ) == rot_act and rot_act_cnt > 5:
                        logprobs[0][torch.exp(
                            logprobs[0]).argmax().item()] = -1e5

                    # pick the most likely action
                    action = torch.exp(logprobs[0]).argmax().item()
                    actions += [action]

                    # accumulate rot_act
                    if action == 0:
                        rot_act = None
                        rot_act_cnt = 0
                    elif action in [1, 2]:
                        if rot_act == action:
                            rot_act_cnt += 1
                        else:
                            rot_act = action
                            rot_act_cnt = 1

                    # interact with environment
                    ego_img, _, episode_done, collision = h3d.step(action)

                    # prepare state for next action
                    ego_img = (torch.from_numpy(ego_img.transpose(
                        2, 0, 1)).float() / 255.).cuda()
                    ego_feat = eval_loader.dataset.cnn.extract_feats(
                        ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)
                    action = torch.LongTensor([action]).cuda()  # (1, )

                    # add to result
                    dists_to_target.append(
                        h3d.get_dist_to_target(h3d.env.cam.pos))
                    pos_queue.append([
                        h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                        h3d.env.cam.pos.z, h3d.env.cam.yaw
                    ])

                    if episode_done:
                        break

                # add to predictions
                pred['d_%s' % str(i)] = {
                    'actions':
                    actions,
                    'pos_queue':
                    pos_queue,
                    'gd_actions':
                    action_outputs[0, prev_len:prev_len + i].tolist(),
                    'gd_pos_queue':
                    eval_loader.dataset.episode_pos_queue[-i:] +
                    [eval_loader.dataset.episode_pos_queue[-1]]
                }

                # is final view looking at target object?
                if tgt_type == 'object':
                    iou = h3d.compute_target_iou(tgt_id)
                    hit_target = 1 if iou >= 0.1 else 0
                    pred['iou_%s' % str(i)] = iou

                # compute stats
                metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                metrics_slug['ep_len_' + str(i)] = episode_length
                metrics_slug['stop_' + str(i)] = 1 if action == 3 else 0
                if tgt_type == 'object':
                    metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                    metrics_slug[
                        'd_D_' +
                        str(i)] = dists_to_target[0] - dists_to_target[-1]
                    metrics_slug['d_min_' + str(i)] = float(
                        np.array(dists_to_target).min())
                    metrics_slug['h_T_' + str(i)] = hit_target
                else:
                    inside_room = [
                        h3d.is_inside_room(p,
                                           eval_loader.dataset.target_room)
                        for p in pos_queue
                    ]
                    # r_T: final position is inside the target room
                    metrics_slug['r_T_' + str(i)] = (
                        1 if inside_room[-1] == True else 0)
                    # r_e: any visited position is inside the target room
                    metrics_slug['r_e_' + str(i)] = (
                        1 if any(x == True for x in inside_room) else 0)

            # collate metrics; anything not measured for this sample falls
            # back to the value currently stored in nav_metrics.
            # (the loop variable is deliberately not `i`, which is logged below)
            metrics_list = []
            for name in nav_metrics.metric_names:
                if name not in metrics_slug:
                    metrics_list.append(nav_metrics.metrics[
                        nav_metrics.metric_names.index(name)][0])
                else:
                    metrics_list.append(metrics_slug[name])

            # update metrics
            if len(metrics_slug) > 0:
                nav_metrics.update(metrics_list)
                predictions.append(pred)

        print(nav_metrics.get_stat_string(mode=0))
        print('invalids', len(invalids))
        # log the last (largest) init step; taken from init_steps so it is
        # defined even when the loader yielded no batch
        logging.info("EVAL: init_steps: {} metrics: {}".format(
            init_steps[-1], nav_metrics.get_stat_string(mode=0)))
        logging.info("EVAL: init_steps: {} invalids: {}".format(
            init_steps[-1], len(invalids)))

        # next environments
        eval_loader.dataset._load_envs()
        print("eval_loader pruned_env_set len: {}".format(
            len(eval_loader.dataset.pruned_env_set)))
        logging.info("eval_loader pruned_env_set len: {}".format(
            len(eval_loader.dataset.pruned_env_set)))
        if len(eval_loader.dataset.pruned_env_set) == 0:
            done = True

    # save results
    nav_metrics.dump_log(predictions)

    # write summary (report the split that was actually evaluated)
    results_str = 'id[%s] split[%s]\n' % (args.id, split)
    for n, r in zip(metric_names, nav_metrics.stats[-1]):
        if r:
            results_str += '%10s: %6.2f,' % (n, r[0])
            if '_15' in n:
                results_str += '\n'
    with open(args.report_txt, 'a') as f:
        f.write(results_str)
        f.write('\n')
예제 #3
0
def imitation(rank, args, shared_nav_model, counter):
    """Imitation-learning worker that trains against a shared navigator.

    Each iteration copies the weights of ``shared_nav_model`` into a local
    ``Navigator``, computes the sequence NLL loss on one batch of
    ``NavImitationDataset``, back-propagates through the local model and
    passes it to ``ensure_shared_grads`` before stepping an Adam optimizer
    built on the shared model's parameters.  The function loops forever and
    never returns.

    Args:
        rank: worker index; selects the CUDA device and offsets the seeds.
        args: namespace with data paths and optimisation hyper-parameters.
        shared_nav_model: model whose parameters the optimizer updates.
        counter: object whose ``value`` is used as the tensorboard step.
    """
    # set up tensorboard
    writer = tb.SummaryWriter(args.tb_dir, filename_suffix=str(rank))

    # set up cuda device
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # set up random seeds (a different seed per worker).
    # np.random.seed is the seeding call; np.random.randn would only draw
    # that many samples and leave NumPy's generator unseeded.
    random.seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    # set up loader
    train_loader_kwargs = {
        'data_json': args.imitation_data_json,
        'data_h5': args.imitation_data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': 'train',
        'max_seq_length': args.max_seq_length,
        'requires_imgs': False,
        'nav_types': args.nav_types,
        'question_types': ['all'],
    }
    train_dataset = NavImitationDataset(**train_loader_kwargs)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=1)
    print('train_loader set up.')

    # set up optimizer on shared_nav_model
    lr = args.learning_rate
    optimizer = torch.optim.Adam(shared_nav_model.parameters(),
                                 lr=lr,
                                 betas=(args.optim_alpha, args.optim_beta),
                                 eps=args.optim_epsilon,
                                 weight_decay=args.weight_decay)

    # set up models
    opt = vars(args)
    opt['act_to_ix'] = train_dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    model.cuda()
    print('navigator set up.')

    # set up criterions
    nll_crit = SeqModelCriterion().cuda()

    # bookkeeping
    epoch = 0
    iters = 0

    # train
    while True:

        for batch in train_loader:
            # sync the local model with the shared weights; it is moved back
            # to the GPU because model.cpu() is called further down
            model.load_state_dict(shared_nav_model.state_dict())
            model.train()
            model.cuda()

            # batch = {qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats, next_feats, res_feats,
            #  action_inputs, action_outputs, action_masks, ego_imgs}
            ego_feats = batch['ego_feats'].cuda()  # (n, L, 3200)
            phrase_embs = batch['phrase_emb'].cuda()  # (n, 300)
            action_inputs = batch['action_inputs'].cuda()  # (n, L)
            action_outputs = batch['action_outputs'].cuda()  # (n, L)
            action_masks = batch['action_masks'].cuda()  # (n, L)
            # forward
            # - logprobs (n, L, #actions)
            # - output_feats (n, L, rnn_size)
            # - pred_feats (n, L, 3200) or None
            logprobs, _, _, _ = model(ego_feats, phrase_embs, action_inputs)
            nll_loss = nll_crit(logprobs, action_outputs, action_masks)

            # backward
            optimizer.zero_grad()
            nll_loss.backward()
            clip_model_gradient(model.parameters(), args.grad_clip)
            ensure_shared_grads(model.cpu(), shared_nav_model)
            optimizer.step()

            if iters % 25 == 0:
                print('imitation-r%s(ep%s it%s lr%.2E loss%.4f)' %
                      (rank, epoch, iters, lr, nll_loss))

            # write to tensorboard
            writer.add_scalar('imitation_rank/nll_loss',
                              float(nll_loss.item()), counter.value)

            # increase iters
            iters += 1

            # decay learning rate: x0.1 every im_learning_rate_decay_every
            # iterations once im_learning_rate_decay_start has been passed
            if args.lr_decay > 0:
                if args.im_learning_rate_decay_start > 0 and iters > args.im_learning_rate_decay_start:
                    frac = (iters - args.im_learning_rate_decay_start
                            ) / args.im_learning_rate_decay_every
                    decay_factor = 0.1**frac
                    lr = args.learning_rate * decay_factor
                    model_utils.set_lr(optimizer, lr)

        epoch += 1
예제 #4
0
def eval(rank, args, shared_nav_model, counter, split='val'):
    """Periodically evaluate the shared navigator and keep the best checkpoint.

    Once per "epoch" (every ``args.num_iters_per_epoch`` ticks of
    ``counter.value``) the weights of ``shared_nav_model`` are copied into a
    local ``Navigator``, which is then rolled out greedily in the simulated
    environments, spawned 15 positions before the end of each recorded path.
    Selected metrics are written to tensorboard, and when the summed score
    ``d_D_15 + h_T_15 + r_T_15`` improves (and more than 50000 iterations have
    elapsed) the predictions and a checkpoint are saved.

    ``metric_names`` is not defined inside this function; it is expected to be
    available from the enclosing module.

    Args:
        rank: worker index; selects the CUDA device and tensorboard sub-dir.
        args: namespace with data/house paths and evaluation options.
        shared_nav_model: model whose ``state_dict`` is evaluated.
        counter: shared object whose ``value`` counts training iterations.
        split: dataset split to evaluate on.
    """
    import time  # local import: only needed for the polite wait below

    # set up cuda device
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # tensorboard
    writer = tb.SummaryWriter(osp.join(args.tb_dir, str(rank)))

    # set up dataset
    cfg = {
        'colorFile':
        osp.join(args.house_meta_dir, 'colormap_fine.csv'),
        'roomTargetFile':
        osp.join(args.house_meta_dir, 'room_target_object_map.csv'),
        'modelCategoryFile':
        osp.join(args.house_meta_dir, 'ModelCategoryMapping.csv'),
        'prefix':
        args.house_data_dir,
    }
    loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': split,
        'max_seq_length': args.max_seq_length,
        'nav_types': args.nav_types,
        'gpu_id': 0,
        'cfg': cfg,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': 500,
        'pretrained_cnn_path': args.pretrained_cnn_path,
        'requires_imgs': False,
        'question_types': ['all'],
    }
    dataset = NavReinforceDataset(**loader_kwargs)
    eval_loader = DataLoader(dataset, batch_size=1, num_workers=0)
    print('eval_loader set up.')

    # set up model
    opt = vars(args)
    opt['act_to_ix'] = dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    print('eval navigator set up.')

    def _metric(name):
        # current value stored in nav_metrics for the metric called `name`
        return nav_metrics.metrics[nav_metrics.metric_names.index(name)][0]

    # evaluate
    epoch = 0
    best_eval_score = None
    while epoch < int(args.max_epochs):
        # wait till counter.value >= epoch * num_iters_per_epoch; the snapshot
        # is reused below so the whole evaluation is attributed to one step
        cur_counter_value = counter.value
        if cur_counter_value / args.num_iters_per_epoch >= epoch:
            epoch += 1
        else:
            # training has not reached the next epoch yet: yield the CPU
            # instead of spinning
            time.sleep(1)
            continue

        # (idx, init_step) pairs that could not be evaluated this epoch
        invalids = []

        # metrics
        nav_metrics = NavMetric(
            info={
                'split': split,
                'thread': rank
            },
            metric_names=metric_names,
            log_json=osp.join(args.log_dir, 'nav_eval_val.json'),
        )

        # update model (fixed from now on)
        model.load_state_dict(shared_nav_model.state_dict())
        model.eval()
        model.cuda()

        # reset envs
        eval_loader.dataset._load_envs(start_idx=0, in_order=True)
        eval_loader.dataset.visited_envs = set()

        # run
        done = False
        predictions = []
        while not done:
            for batch in tqdm(eval_loader):
                # load target_paths from available_idx (of some envs)
                # batch = {idx, qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats,
                # action_inputs, action_outputs, action_masks, action_length}
                idx = batch['idx'][0].item()
                qid = batch['qid'][0].item()
                phrase_emb = batch['phrase_emb'].cuda()  # (1, 300) cuda
                raw_ego_feats = batch['ego_feats']  # (1, L, 3200) float
                action_inputs = batch['action_inputs']  # (1, L) int
                action_outputs = batch['action_outputs']  # (1, L) int
                action_length = batch['action_length'][0].item()
                tgt_type = batch['type'][0]
                tgt_id = batch['id'][0]
                pred = {
                    'qid': qid,
                    'house': batch['house'][0],
                    'id': batch['id'][0],
                    'type': batch['type'][0],
                    'path_name': batch['path_name'][0],
                    'path_ix': batch['path_ix'][0].item(),
                    'start_ix': batch['start_ix'][0].item(),
                    'key_ix': batch['key_ix'][0].item(),
                    'action_length': action_length,
                    'phrase': batch['phrase'][0],
                }
                metrics_slug = {}

                # evaluate at multiple initializations
                # (only 15 here, instead of [5, 10, 15], to save evaluation time)
                for i in [15]:

                    if action_length - i < 0:
                        invalids.append((idx, i))
                        continue

                    h3d = eval_loader.dataset.episode_house
                    episode_length = 0
                    dists_to_target = []
                    pos_queue = []
                    actions = []

                    # forward through navigator till spawn
                    prev_pos_queue = eval_loader.dataset.episode_pos_queue[:-i]
                    # prev_len is bound on both branches so the 'gd_actions'
                    # slice below never reads a missing or stale value
                    prev_len = len(prev_pos_queue)
                    if prev_len:
                        ego_feats_pruned = raw_ego_feats[:, :prev_len, :].cuda(
                        )  # (1, l, 3200)
                        action_inputs_pruned = action_inputs[:, :prev_len].cuda(
                        )  # (1, l)
                        # evaluation only: no autograd graph is needed
                        with torch.no_grad():
                            _, _, _, state = model(
                                ego_feats_pruned, phrase_emb,
                                action_inputs_pruned)  # (1, l, rnn_size)
                        action = action_inputs[0, prev_len].view(
                            -1).cuda()  # (1, )
                        init_pos = eval_loader.dataset.episode_pos_queue[-i]
                    else:
                        state = None
                        action = torch.LongTensor([
                            eval_loader.dataset.act_to_ix['dummy']
                        ]).cuda()  # (1, )
                        init_pos = eval_loader.dataset.episode_pos_queue[
                            0]  # use first position instead

                    # spawn
                    h3d.env.reset(x=init_pos[0],
                                  y=init_pos[2],
                                  yaw=init_pos[3])
                    init_dist_to_target = h3d.get_dist_to_target(
                        h3d.env.cam.pos)
                    if init_dist_to_target < 0:  # unreachable
                        invalids.append([idx, i])
                        continue
                    dists_to_target += [init_dist_to_target]
                    pos_queue += [init_pos]

                    # act
                    ego_img = h3d.env.render()
                    ego_img = (
                        torch.from_numpy(ego_img.transpose(2, 0, 1)).float() /
                        255.).cuda()
                    ego_feat = eval_loader.dataset.cnn.extract_feats(
                        ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)
                    collision = False
                    rot_act = None
                    rot_act_cnt = 0
                    for step in range(args.max_episode_length):
                        # forward model
                        episode_length += 1
                        with torch.no_grad():
                            logprobs, state = model.forward_step(
                                ego_feat, phrase_emb, action,
                                state)  # (1, 4), (1, rnn_size)

                        # special case 1: if previous action is "forward" and a collision
                        # happened and this action is still "forward", suppress it.
                        if action.item() == 0 and collision and torch.exp(
                                logprobs[0]).argmax().item() == 0:
                            logprobs[0][0] = -1e5

                        # special case 2:
                        # if the same rotation was chosen more than 5 times in a row,
                        # suppress it
                        if torch.exp(logprobs[0]).argmax().item(
                        ) == rot_act and rot_act_cnt > 5:
                            logprobs[0][torch.exp(
                                logprobs[0]).argmax().item()] = -1e5

                        # pick the most likely action
                        action = torch.exp(logprobs[0]).argmax().item()
                        actions += [action]

                        # accumulate rot_act
                        if action == 0:
                            rot_act = None
                            rot_act_cnt = 0
                        elif action in [1, 2]:
                            if rot_act == action:
                                rot_act_cnt += 1
                            else:
                                rot_act = action
                                rot_act_cnt = 1

                        # interact with environment
                        ego_img, _, episode_done, collision = h3d.step(action)
                        episode_done = episode_done or episode_length >= args.max_episode_length  # no need actually

                        # prepare state for next action
                        ego_img = (torch.from_numpy(ego_img.transpose(
                            2, 0, 1)).float() / 255.).cuda()
                        ego_feat = eval_loader.dataset.cnn.extract_feats(
                            ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)
                        action = torch.LongTensor([action]).cuda()  # (1, )

                        # add to result
                        dists_to_target.append(
                            h3d.get_dist_to_target(h3d.env.cam.pos))
                        pos_queue.append([
                            h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                            h3d.env.cam.pos.z, h3d.env.cam.yaw
                        ])

                        if episode_done:
                            break

                    # add to predictions
                    pred['d_%s' % str(i)] = {
                        'actions':
                        actions,
                        'pos_queue':
                        pos_queue,
                        'gd_actions':
                        action_outputs[0, prev_len:prev_len + i].tolist(),
                        'gd_pos_queue':
                        eval_loader.dataset.episode_pos_queue[-i:] +
                        [eval_loader.dataset.episode_pos_queue[-1]]
                    }

                    # compute stats
                    metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                    metrics_slug['ep_len_' + str(i)] = episode_length
                    metrics_slug['stop_' + str(i)] = 1 if action == 3 else 0
                    if tgt_type == 'object':
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug[
                            'd_D_' +
                            str(i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' + str(i)] = float(
                            np.array(dists_to_target).min())
                        iou = h3d.compute_target_iou(tgt_id)
                        metrics_slug['h_T_' + str(i)] = 1 if iou >= 0.1 else 0
                    else:
                        inside_room = [
                            h3d.is_inside_room(
                                p, eval_loader.dataset.target_room)
                            for p in pos_queue
                        ]
                        # r_T: final position is inside the target room
                        metrics_slug['r_T_' +
                                     str(i)] = 1 if inside_room[-1] else 0
                        # r_e: any visited position is inside the target room
                        metrics_slug['r_e_' + str(i)] = 1 if any(
                            x == True for x in inside_room) else 0

                # collate metrics; anything not measured for this sample
                # falls back to the value currently stored in nav_metrics
                metrics_list = []
                for name in nav_metrics.metric_names:
                    if name not in metrics_slug:
                        metrics_list.append(_metric(name))
                    else:
                        metrics_list.append(metrics_slug[name])

                # update metrics
                if len(metrics_slug) > 0:
                    nav_metrics.update(metrics_list)
                    predictions.append(pred)

            print(nav_metrics.get_stat_string(mode=0))
            print('invalids', len(invalids))
            # 15 is the only init step evaluated in this function
            logging.info("EVAL: init_steps: {} metrics: {}".format(
                15, nav_metrics.get_stat_string(mode=0)))
            logging.info("EVAL: init_steps: {} invalids: {}".format(
                15, len(invalids)))
            print('%s/%s envs visited.' %
                  (len(eval_loader.dataset.visited_envs),
                   len(eval_loader.dataset.env_set)))

            # next environments
            eval_loader.dataset._load_envs()
            print("eval_loader pruned_env_set len: {}".format(
                len(eval_loader.dataset.pruned_env_set)))
            logging.info("eval_loader pruned_env_set len: {}".format(
                len(eval_loader.dataset.pruned_env_set)))
            if len(eval_loader.dataset.pruned_env_set) == 0:
                done = True

        # write to tensorboard (falsy values, e.g. unset metrics, are skipped)
        for metric_name in [
                'd_T_5', 'd_T_10', 'd_T_15', 'd_D_5', 'd_D_10', 'd_D_15',
                'h_T_5', 'h_T_10', 'h_T_15', 'r_T_5', 'r_T_10', 'r_T_15',
                'r_e_5', 'r_e_10', 'r_e_15'
        ]:
            value = _metric(metric_name)
            if value:
                # instead of counter.value (as we started eval at cur_counter_value)
                writer.add_scalar('eval/%s' % metric_name, value,
                                  cur_counter_value)

        # save if best
        # best_score = d_D_15 + h_T_15 + r_T_15
        cur_score = 0
        for metric_name in ('d_D_15', 'h_T_15', 'r_T_15'):
            value = _metric(metric_name)
            if value:
                cur_score += value
        if (best_eval_score is None
                or cur_score > best_eval_score) and cur_counter_value > 50000:
            best_eval_score = cur_score
            nav_metrics.dump_log(predictions)
            checkpoint_path = osp.join(args.checkpoint_dir, '%s.pth' % args.id)
            checkpoint = {}
            checkpoint['model_state'] = model.cpu().state_dict()
            checkpoint['opt'] = vars(args)
            torch.save(checkpoint, checkpoint_path)
            print('model saved to %s.' % checkpoint_path)

    # flush pending tensorboard events before the worker exits
    writer.close()
예제 #5
0
def main(args):
    """Train the object navigator by imitation and keep the best checkpoint.

    Builds train/val ``NavImitationDataset`` instances restricted to
    ``nav_types=['object']``, trains a ``Navigator`` with Adam on the
    sequence NLL loss (plus an optional masked MSE feature loss), and every
    ``args.save_checkpoint_every`` iterations evaluates on the val split.
    Whenever the validation NLL improves, the model is saved to
    ``<checkpoint_dir>/<id>.pth``; the training/validation histories are
    written to ``<checkpoint_dir>/<id>.json`` after every evaluation.

    Args:
        args: parsed command-line namespace. Attributes read here include
            ``checkpoint_dir``, ``seed``, ``id``, the data paths
            (``data_json``, ``data_h5``, ``path_feats_dir``,
            ``path_images_dir``), ``max_seq_length``, ``batch_size``,
            ``num_workers``, the optimizer settings, ``use_curriculum``,
            ``use_next``, ``use_residual``, ``mse_weight``, ``grad_clip``,
            ``max_iters`` and the logging / decay / checkpoint intervals.
            ``args.checkpoint_dir`` is filled with a default when ``None``,
            and ``act_to_ix`` / ``num_actions`` are added to the namespace
            through ``vars(args)``.
    """
    # make output directory
    if args.checkpoint_dir is None:
        args.checkpoint_dir = 'output/nav_object'
    if not osp.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    # set random seed
    random.seed(args.seed)
    # np.random.randn(seed) would merely draw `seed` samples and leave the
    # generator unseeded; np.random.seed is what makes numpy reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # set up loaders
    train_loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': 'train',
        'max_seq_length': args.max_seq_length,
        'requires_imgs': False,
        'nav_types': ['object'],
        'question_types': ['all'],
    }
    train_dataset = NavImitationDataset(**train_loader_kwargs)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers)
    # the val dataset uses the same configuration, only the split differs
    val_loader_kwargs = dict(train_loader_kwargs, split='val')
    val_dataset = NavImitationDataset(**val_loader_kwargs)

    # set up models
    # NOTE: vars(args) is the namespace's own dict, so these keys also
    # become attributes of `args` and end up in the saved checkpoint.
    opt = vars(args)
    opt['act_to_ix'] = train_dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    model.cuda()
    print('navigator set up.')

    # set up criterions
    nll_crit = SeqModelCriterion().cuda()
    mse_crit = MaskedMSELoss().cuda()

    # set up optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(args.optim_alpha, args.optim_beta),
                                 eps=args.optim_epsilon,
                                 weight_decay=args.weight_decay)

    # training state
    # NOTE: `infos` starts empty (nothing is loaded from disk here), so every
    # value below takes its default; the .get() calls only shape the layout
    # of the json report written during evaluation.
    infos = {}
    iters = infos.get('iters', 0)
    epoch = infos.get('epoch', 0)
    val_nll_history = infos.get('val_nll_history', {})
    val_mse_history = infos.get('val_mse_history', {})
    val_teacher_forcing_acc_history = infos.get(
        'val_teacher_forcing_acc_history', {})
    val_nav_object_nll_history = infos.get('val_nav_object_nll_history', {})
    val_nav_object_teacher_forcing_acc_history = infos.get(
        'val_nav_object_teacher_forcing_acc_history', {})
    val_nav_room_nll_history = infos.get('val_nav_room_nll_history', {})
    val_nav_room_teacher_forcing_acc_history = infos.get(
        'val_nav_room_teacher_forcing_acc_history', {})
    loss_history = infos.get('loss_history', {})
    nll_loss_history = infos.get('nll_loss_history', {})
    mse_loss_history = infos.get('mse_loss_history', {})
    lr = infos.get('lr', args.learning_rate)
    best_val_score, best_val_acc, best_predictions = None, None, None

    # start training
    # NOTE: max_iters is only checked between epochs, so the epoch in which
    # the limit is crossed still runs to completion.
    while iters <= args.max_iters:
        print('Starting epoch %d' % epoch)
        # reset seq_length
        if args.use_curriculum:
            # Grow linearly by a quarter of max_seq_length per epoch so the
            # full length is reached after 4 epochs. (The previous `**`
            # jumped to the full length already in the second epoch.)
            seq_length = min((args.max_seq_length // 4) * (epoch + 1),
                             args.max_seq_length)
            train_dataset.reset_seq_length(seq_length)
        else:
            seq_length = args.max_seq_length
        # train
        for batch in train_loader:
            # set mode (evaluate() below may have switched it)
            model.train()
            # zero gradient
            optimizer.zero_grad()
            # batch = {qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats, next_feats, res_feats,
            #  action_inputs, action_outputs, action_masks, ego_imgs}
            ego_feats = batch['ego_feats'].cuda()  # (n, L, 3200)
            phrase_embs = batch['phrase_emb'].cuda()  # (n, 300)
            action_inputs = batch['action_inputs'].cuda()  # (n, L)
            action_outputs = batch['action_outputs'].cuda()  # (n, L)
            action_masks = batch['action_masks'].cuda()  # (n, L)
            # forward
            # - logprobs (n, L, #actions)
            # - output_feats (n, L, rnn_size)
            # - pred_feats (n, L, 3200) or None
            logprobs, _, pred_feats, _ = model(ego_feats, phrase_embs,
                                               action_inputs)
            nll_loss = nll_crit(logprobs, action_outputs, action_masks)
            # auxiliary feature-regression loss; if both flags are set the
            # residual target overrides the next-frame target
            mse_loss = 0
            if args.use_next:
                next_feats = batch['next_feats'].cuda()  # (n, L, 3200)
                mse_loss = mse_crit(pred_feats, next_feats, action_masks)
            if args.use_residual:
                res_feats = batch['res_feats'].cuda()  # (n, L, 3200)
                mse_loss = mse_crit(pred_feats, res_feats, action_masks)
            loss = nll_loss + args.mse_weight * mse_loss
            # backward
            loss.backward()
            model_utils.clip_gradient(optimizer, args.grad_clip)
            optimizer.step()

            # training log
            if iters % args.losses_log_every == 0:
                loss_history[iters] = loss.item()
                nll_loss_history[iters] = nll_loss.item()
                # mse_loss is the plain int 0 when no auxiliary loss is used
                mse_loss_history[iters] = mse_loss.item() if (
                    args.use_next or args.use_residual) else 0
                print('iters[%s]epoch[%s], train_loss=%.3f (nll_loss=%.3f, mse_loss=%.3f) lr=%.2E, cur_seq_length=%s' % \
                  (iters, epoch, loss_history[iters], nll_loss_history[iters], mse_loss_history[iters], lr, train_loader.dataset.cur_seq_length))

            # decay learning rate: x0.1 per learning_rate_decay_every
            # iterations past learning_rate_decay_start (continuous exponent)
            if args.learning_rate_decay_start > 0 and iters > args.learning_rate_decay_start:
                frac = (iters - args.learning_rate_decay_start
                        ) / args.learning_rate_decay_every
                decay_factor = 0.1**frac
                lr = args.learning_rate * decay_factor
                model_utils.set_lr(optimizer, lr)

            # evaluate
            if iters % args.save_checkpoint_every == 0:
                print('Checking validation ...')
                predictions, overall_nll, overall_teacher_forcing_acc, overall_mse, Nav_nll, Nav_teacher_forcing_acc = \
                  evaluate(val_dataset, model, nll_crit, mse_crit, opt)
                val_nll_history[iters] = overall_nll
                val_teacher_forcing_acc_history[
                    iters] = overall_teacher_forcing_acc
                val_mse_history[iters] = overall_mse
                val_nav_object_nll_history[iters] = Nav_nll['object']
                val_nav_object_teacher_forcing_acc_history[
                    iters] = Nav_teacher_forcing_acc['object']
                val_nav_room_nll_history[iters] = Nav_nll['room']
                val_nav_room_teacher_forcing_acc_history[
                    iters] = Nav_teacher_forcing_acc['room']

                # save model if best
                # model selection is by validation NLL only (lower is better)
                current_score = -overall_nll
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_val_acc = overall_teacher_forcing_acc
                    best_predictions = predictions
                    checkpoint_path = osp.join(args.checkpoint_dir,
                                               '%s.pth' % args.id)
                    checkpoint = {}
                    checkpoint['model_state'] = model.state_dict()
                    checkpoint['opt'] = vars(args)
                    torch.save(checkpoint, checkpoint_path)
                    print('model saved to %s.' % checkpoint_path)

                # write to json report
                infos['iters'] = iters
                infos['epoch'] = epoch
                infos['loss_history'] = loss_history
                infos['nll_loss_history'] = nll_loss_history
                infos['mse_loss_history'] = mse_loss_history
                infos['val_nll_history'] = val_nll_history
                infos[
                    'val_teacher_forcing_acc_history'] = val_teacher_forcing_acc_history
                infos['val_mse_history'] = val_mse_history
                infos[
                    'val_nav_object_nll_history'] = val_nav_object_nll_history
                infos[
                    'val_nav_object_teacher_forcing_acc_history'] = val_nav_object_teacher_forcing_acc_history
                infos['val_nav_room_nll_history'] = val_nav_room_nll_history
                infos[
                    'val_nav_room_teacher_forcing_acc_history'] = val_nav_room_teacher_forcing_acc_history
                infos['best_val_score'] = best_val_score
                infos['best_val_acc'] = best_val_acc
                infos[
                    'best_predictions'] = predictions if best_predictions is None else best_predictions
                infos['opt'] = vars(args)
                infos['act_to_ix'] = train_dataset.act_to_ix
                infos_json = osp.join(args.checkpoint_dir, '%s.json' % args.id)
                with open(infos_json, 'w') as f:
                    json.dump(infos, f)
                print('infos saved to %s.' % infos_json)

            # update iters
            iters += 1

        # update epoch
        epoch += 1
예제 #6
0
def train(rank, args, shared_nav_model, counter, lock):
    """Run one REINFORCE training worker against the shared navigator.

    Each episode the worker copies the weights of ``shared_nav_model`` into a
    local ``Navigator``, spawns the agent (either part-way along the stored
    shortest path or at a sampled position), rolls the policy out in the
    environment, computes a normalized-return policy-gradient loss, and
    pushes the resulting gradients onto ``shared_nav_model`` before stepping
    an Adam optimizer built on the shared parameters.

    The outer ``while True`` loop has no exit, so this function never
    returns; the worker is expected to be stopped from outside.

    Args:
        rank: index of this worker; used for the tensorboard sub-directory,
            the device choice, the random seeds, the log file name and the
            ``ratio`` slice handed to the dataset.
        args: parsed options namespace (paths, optimizer settings,
            ``gpus``, ``num_processes``, ``nav_types``,
            ``max_episode_length``, ``shortest_path_ratio``, ...).
            ``act_to_ix`` and ``num_actions`` are added to it through
            ``vars(args)``.
        shared_nav_model: model whose parameters are optimized; the local
            model is re-synced from it at the start of every episode.
        counter: object with a ``.value`` attribute counting finished
            episodes; used as the tensorboard step and incremented here.
        lock: context manager held while ``counter.value`` is incremented.
    """
    # set up tensorboard
    # writer = tb.SummaryWriter(args.tb_dir, filename_suffix=str(rank))
    writer = tb.SummaryWriter(osp.join(args.tb_dir, str(rank)))

    # set up cuda device
    # gpu_id is the *position* of this worker's gpu inside args.gpus
    # NOTE(review): presumably CUDA_VISIBLE_DEVICES is set to args.gpus so
    # that position equals the visible device ordinal -- confirm.
    gpu_id = args.gpus.index(args.gpus[rank % len(args.gpus)])
    torch.cuda.set_device(gpu_id)

    # set up random seeds (offset by rank so workers differ)
    random.seed(args.seed + rank)
    # np.random.randn(n) only draws n samples and leaves the generator
    # unseeded; np.random.seed makes the np.random.uniform spawn choice
    # below reproducible per worker.
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    # set up dataset
    cfg = {
        'colorFile':
        osp.join(args.house_meta_dir, 'colormap_fine.csv'),
        'roomTargetFile':
        osp.join(args.house_meta_dir, 'room_target_object_map.csv'),
        'modelCategoryFile':
        osp.join(args.house_meta_dir, 'ModelCategoryMapping.csv'),
        'prefix':
        args.house_data_dir,
    }
    loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': 'train',
        'max_seq_length': args.max_seq_length,
        'nav_types': args.nav_types,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'cfg': cfg,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': 500,
        'pretrained_cnn_path': args.pretrained_cnn_path,
        'requires_imgs': False,
        'question_types': ['all'],
        # fraction interval [start, end) of the data assigned to this worker
        'ratio': [rank / args.num_processes, (rank + 1) / args.num_processes]
    }
    dataset = NavReinforceDataset(**loader_kwargs)
    train_loader = DataLoader(dataset, batch_size=1, num_workers=0)
    print('train_loader set up.')

    # set up optimizer on shared_nav_model
    lr = args.learning_rate
    optimizer = torch.optim.Adam(shared_nav_model.parameters(),
                                 lr=lr,
                                 betas=(args.optim_alpha, args.optim_beta),
                                 eps=args.optim_epsilon,
                                 weight_decay=args.weight_decay)

    # set up model
    opt = vars(args)
    opt['act_to_ix'] = dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    print('navigator[%s] set up.' % rank)

    # set up metrics outside epoch, as we use running rewards through whole training
    nav_metrics = NavMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['reward', 'episode_length'],
        log_json=osp.join(args.log_dir, 'nav_train_' + str(rank) + '.json'),
    )
    nav_metrics.update([0, 100])
    reward_list, episode_length_list = [], []
    rwd_stats = Statistics()  # computing running mean and std of rewards

    # path length multiplier
    is_object_nav = args.nav_types == ['object']
    min_dist = 5 if is_object_nav else 15  # 15 for rl3 and rl5, 10 for rl4
    # max_dist = 25 if args.nav_types == ['object'] else 40
    max_dist = 35 if is_object_nav else 50
    mult = 0.1 if is_object_nav else 0.15
    rwd_thresh = 0.1 if is_object_nav else 0.00
    epoch = 0
    iters = 0

    # train
    while True:

        # reset envs
        train_loader.dataset._load_envs(start_idx=0, in_order=True)
        done = False  # current epoch is not done yet
        while not done:

            for batch in train_loader:
                # sync model
                model.load_state_dict(shared_nav_model.state_dict())
                model.train()
                # model was moved to cpu by the previous episode's
                # ensure_shared_grads(model.cpu(), ...) call
                model.cuda()

                # load target_paths from available_idx (of some envs)
                # batch = {idx, qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats,
                # action_inputs, action_outputs, action_masks, action_length}
                idx = batch['idx'][0].item()
                qid = batch['qid'][0].item()
                phrase_emb = batch['phrase_emb'].cuda()  # (1, 300) float
                raw_ego_feats = batch['ego_feats']  # (1, L, 3200) float
                action_inputs = batch['action_inputs']  # (1, L) int
                action_outputs = batch['action_outputs']  # (1, L) int
                action_length = batch['action_length'][0].item()
                tgt_type = batch['type'][0]
                tgt_id = batch['id'][0]
                tgt_phrase = batch['phrase'][0]

                # to be recorded
                episode_length = 0
                episode_done = True
                dists_to_target = []
                pos_queue = []
                actions = []
                rewards = []
                nav_log_probs = []

                # spawn agent
                h3d = train_loader.dataset.episode_house
                if np.random.uniform(0, 1, 1)[0] <= args.shortest_path_ratio:
                    # with probability shortest_path_ratio, spawn the agent on
                    # the shortest path, vlen steps before its end
                    use_shortest_path = True
                    vlen = min(max(min_dist, int(mult * action_length)),
                               action_length)
                    # forward through navigator till spawn
                    episode_pos_queue = train_loader.dataset.episode_pos_queue
                    if len(episode_pos_queue) > vlen:
                        # positions visited before the spawn point
                        prev_pos_queue = episode_pos_queue[:-vlen]
                        n_prev = len(prev_pos_queue)
                        ego_feats_pruned = raw_ego_feats[:, :n_prev, :].cuda(
                        )  # (1, l, 3200)
                        action_inputs_pruned = action_inputs[:, :n_prev].cuda(
                        )  # (1, l)
                        # warm up the recurrent state on the path prefix
                        _, _, _, state = model(
                            ego_feats_pruned, phrase_emb,
                            action_inputs_pruned)  # (1, l, rnn_size)
                        action = action_inputs[0, n_prev].view(
                            -1).cuda()  # (1, )
                        init_pos = episode_pos_queue[-vlen]
                    else:
                        # path too short: start from its first position
                        state = None
                        action = torch.LongTensor([
                            train_loader.dataset.act_to_ix['dummy']
                        ]).cuda()  # (1, )
                        init_pos = episode_pos_queue[0]
                else:
                    # otherwise randomly spawn the agent
                    use_shortest_path = False
                    state = None
                    action = torch.LongTensor([
                        train_loader.dataset.act_to_ix['dummy']
                    ]).cuda()  # (1, )
                    init_pos, vlen = train_loader.dataset.spawn_agent(
                        min_dist, max_dist)
                    if init_pos is None:  # init_pos not found
                        continue

                # initiate
                h3d.env.reset(x=init_pos[0], y=init_pos[2], yaw=init_pos[3])
                init_dist_to_target = h3d.get_dist_to_target(h3d.env.cam.pos)
                if init_dist_to_target < 0:  # unreachable
                    continue
                dists_to_target += [init_dist_to_target]
                pos_queue += [init_pos]

                # act
                ego_img = h3d.env.render()
                # HWC image -> CHW float tensor scaled by 1/255
                ego_img = (
                    torch.from_numpy(ego_img.transpose(2, 0, 1)).float() /
                    255.).cuda()
                ego_feat = train_loader.dataset.cnn.extract_feats(
                    ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)
                prev_action, collision = None, False
                for step in range(args.max_episode_length):
                    # forward model one step
                    episode_length += 1
                    logprobs, state = model.forward_step(
                        ego_feat, phrase_emb, action,
                        state)  # (1, 4), (1, rnn_size)

                    # sample action
                    probs = torch.exp(logprobs)  # (1, 4)
                    action = probs.multinomial(
                        num_samples=1).detach()  # (1, 1)
                    if prev_action == 0 and collision and action[0][0].item(
                    ) == 0:
                        # special case: prev_action == "forward" && collision && cur_action == "forward"
                        # we sample from {'left', 'right', 'stop'} only
                        action = probs[0:1, 1:].multinomial(
                            num_samples=1).detach() + 1  # (1, 1)

                    if len(pos_queue) < min_dist:
                        # special case: our room navigator tends to stop early, let's push it to explore longer
                        # (resample from the first three actions only)
                        action = probs[0:1, :3].multinomial(
                            num_samples=1).detach()  # (1, 1)

                    nav_log_probs.append(logprobs.gather(1, action))  # (1, 1)
                    action = action.view(-1)  # (1, )
                    actions.append(action)

                    # interact with environment
                    ego_img, reward, episode_done, collision = h3d.step(
                        action[0].item(), step_reward=True)
                    if not episode_done:
                        reward -= 0.01  # we don't wanna too long trajectory
                    episode_done = episode_done or episode_length >= args.max_episode_length  # no need actually
                    reward = max(min(reward, 1), -1)  # clip to [-1, 1]
                    rewards.append(reward)
                    prev_action = action[0].item()

                    # prepare state for next action
                    ego_img = (
                        torch.from_numpy(ego_img.transpose(2, 0, 1)).float() /
                        255.).cuda()
                    ego_feat = train_loader.dataset.cnn.extract_feats(
                        ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)

                    # add to result
                    dists_to_target.append(
                        h3d.get_dist_to_target(h3d.env.cam.pos))
                    pos_queue.append([
                        h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                        h3d.env.cam.pos.z, h3d.env.cam.yaw
                    ])

                    if episode_done:
                        break

                # final reward
                R = 0
                if tgt_type == 'object':
                    R = 1.0 if h3d.compute_target_iou(tgt_id) >= 0.1 else -1.0
                else:
                    R = 0.2 if h3d.is_inside_room(
                        pos_queue[-1],
                        train_loader.dataset.target_room) else -0.2
                    # NOTE(review): room_to_objects is not defined inside
                    # this function; it is expected at module scope.
                    if R > 0 and h3d.compute_room_targets_iou(
                            room_to_objects[tgt_phrase]) > 0.1:
                        R += 1.0  # encourage agent to move to room-specific targets

                # backward
                nav_loss = 0
                new_rewards = [
                ]  # recording reshaped rewards, to be computed as moving average
                advantages = []
                # discounted returns (gamma=0.99), accumulated from the last
                # step backwards and seeded with the final reward R
                for step_reward in reversed(rewards):
                    R = 0.99 * R + step_reward
                    new_rewards.insert(0, R)
                    rwd_stats.push(R)
                for nav_log_prob, R in zip(nav_log_probs, new_rewards):
                    # normalize returns by the running mean/std baseline
                    advantage = (R - rwd_stats.mean()) / (
                        rwd_stats.stddev() + 1e-5)  # rl2, rl3
                    # advantage = R - rwd_stats.mean()  # rl1
                    nav_loss = nav_loss - nav_log_prob * advantage
                    advantages.insert(0, advantage)
                nav_loss /= max(1, len(nav_log_probs))

                optimizer.zero_grad()
                nav_loss.backward()
                clip_model_gradient(model.parameters(), args.grad_clip)
                # Till this point, we have grads on model's parameters!
                # but shared_nav_model's grads is still not updated
                ensure_shared_grads(model.cpu(), shared_nav_model)
                optimizer.step()

                if iters % 5 == 0:
                    log_info = 'train-r%2s(ep%2s it%5s lr%.2E mult%.2f, run_rwd%6.3f, bs_rwd%6.3f, bs_std%2.2f), vlen(%s)=%2s, rwd=%6.3f, ep_len=%3s, tgt:%s' % \
                      (rank, epoch, iters, lr, mult, nav_metrics.metrics[0][1], rwd_stats.mean(), rwd_stats.stddev(), 'short' if use_shortest_path else 'spawn', vlen, np.mean(new_rewards), len(new_rewards), tgt_phrase)
                    print(log_info)

                # update metrics (flushed once more than 50 episodes have
                # accumulated)
                reward_list += new_rewards
                episode_length_list.append(episode_length)
                if len(episode_length_list) > 50:
                    nav_metrics.update([reward_list, episode_length_list])
                    nav_metrics.dump_log()
                    reward_list, episode_length_list = [], []

                # write to tensorboard
                writer.add_scalar('train_rank%s/nav_loss' % rank,
                                  nav_loss.item(), counter.value)
                writer.add_scalar('train_rank%s/steps_avg_rwd' % rank,
                                  float(np.mean(new_rewards)), counter.value)
                writer.add_scalar('train_rank%s/steps_sum_rwd' % rank,
                                  float(np.sum(new_rewards)), counter.value)
                writer.add_scalar('train_rank%s/advantage' % rank,
                                  float(np.mean(advantages)), counter.value)

                # increase counter as this episode ends
                with lock:
                    counter.value += 1

                # increase mult
                # (mult sets how far from the path's end the agent is spawned
                # on the shortest path; it is raised while the running reward
                # beats the threshold and lowered otherwise)
                iters += 1
                if nav_metrics.metrics[0][
                        1] > rwd_thresh:  # baseline = nav_metrics.metrics[0][1]
                    mult = min(mult + 0.1, 1.0)
                    rwd_thresh += 0.01
                else:
                    mult = max(mult - 0.1, 0.1)
                    rwd_thresh -= 0.01
                rwd_thresh = max(0.1, min(rwd_thresh, 0.2))

                # decay learning rate
                if args.learning_rate_decay_start > 0 and iters > args.learning_rate_decay_start:
                    frac = (iters - args.learning_rate_decay_start
                            ) / args.learning_rate_decay_every
                    decay_factor = 0.1**frac
                    lr = args.learning_rate * decay_factor
                    model_utils.set_lr(optimizer, lr)

            # next environments
            train_loader.dataset._load_envs(in_order=True)
            print("train_loader pruned_env_set len: {}".format(
                len(train_loader.dataset.pruned_env_set)))
            logging.info("train_loader pruned_env_set len: {}".format(
                len(train_loader.dataset.pruned_env_set)))
            # an empty pruned_env_set ends the current epoch
            if len(train_loader.dataset.pruned_env_set) == 0:
                done = True

        epoch += 1