Ejemplo n.º 1
0
def main(args):
    """Train the object-navigation imitation model, validating periodically.

    Builds train/val ``NavImitationDataset``s restricted to
    ``nav_types=['object']``, a ``Navigator`` on the current CUDA device, and an
    Adam optimizer. Training continues while ``iters <= args.max_iters``; that
    condition is only re-checked between epochs, so the last epoch always runs
    to completion.

    Every ``args.save_checkpoint_every`` iterations the model is evaluated on
    the val split:
      * the weights with the best (lowest) overall val NLL so far are saved to
        ``<checkpoint_dir>/<id>.pth``;
      * all loss/metric histories are written to ``<checkpoint_dir>/<id>.json``.

    Args:
        args: parsed command-line namespace. ``vars(args)`` is used as the model
            option dict, so ``act_to_ix`` and ``num_actions`` are added to
            ``args`` itself and end up in the saved ``'opt'`` entries.
    """
    # make output directory (exist_ok also tolerates concurrent creation)
    if args.checkpoint_dir is None:
        args.checkpoint_dir = 'output/nav_object'
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # set random seeds (np.random.seed, not np.random.randn, seeds NumPy)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # set up loaders: train and val share every option except the split
    train_loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': 'train',
        'max_seq_length': args.max_seq_length,
        'requires_imgs': False,
        'nav_types': ['object'],
        'question_types': ['all'],
    }
    train_dataset = NavImitationDataset(**train_loader_kwargs)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers)
    val_loader_kwargs = dict(train_loader_kwargs, split='val')
    val_dataset = NavImitationDataset(**val_loader_kwargs)

    # set up models; vars(args) is the namespace's own __dict__, so the keys
    # added here also become attributes of args
    opt = vars(args)
    opt['act_to_ix'] = train_dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    model.cuda()
    print('navigator set up.')

    # set up criterions
    nll_crit = SeqModelCriterion().cuda()
    mse_crit = MaskedMSELoss().cuda()

    # set up optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(args.optim_alpha, args.optim_beta),
                                 eps=args.optim_epsilon,
                                 weight_decay=args.weight_decay)

    # training state; infos starts empty, so every .get() below yields its
    # default (presumably a placeholder for resuming from a saved infos dict)
    infos = {}
    iters = infos.get('iters', 0)
    epoch = infos.get('epoch', 0)
    val_nll_history = infos.get('val_nll_history', {})
    val_mse_history = infos.get('val_mse_history', {})
    val_teacher_forcing_acc_history = infos.get(
        'val_teacher_forcing_acc_history', {})
    val_nav_object_nll_history = infos.get('val_nav_object_nll_history', {})
    val_nav_object_teacher_forcing_acc_history = infos.get(
        'val_nav_object_teacher_forcing_acc_history', {})
    val_nav_room_nll_history = infos.get('val_nav_room_nll_history', {})
    val_nav_room_teacher_forcing_acc_history = infos.get(
        'val_nav_room_teacher_forcing_acc_history', {})
    loss_history = infos.get('loss_history', {})
    nll_loss_history = infos.get('nll_loss_history', {})
    mse_loss_history = infos.get('mse_loss_history', {})
    lr = infos.get('lr', args.learning_rate)
    best_val_score, best_val_acc, best_predictions = None, None, None

    # start training
    while iters <= args.max_iters:
        print('Starting epoch %d' % epoch)
        # reset seq_length
        if args.use_curriculum:
            # assume we need 4 epochs to get full seq_length: the allowed
            # length grows linearly by a quarter of max_seq_length per epoch
            # (the previous `** (epoch + 1)` grew exponentially instead)
            seq_length = max(1, min(args.max_seq_length * (epoch + 1) // 4,
                                    args.max_seq_length))
            train_dataset.reset_seq_length(seq_length)
        else:
            seq_length = args.max_seq_length
        # train
        for batch in train_loader:
            # set mode
            model.train()
            # zero gradient
            optimizer.zero_grad()
            # batch = {qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats, next_feats, res_feats,
            #  action_inputs, action_outputs, action_masks, ego_imgs}
            ego_feats = batch['ego_feats'].cuda()  # (n, L, 3200)
            phrase_embs = batch['phrase_emb'].cuda()  # (n, 300)
            action_inputs = batch['action_inputs'].cuda()  # (n, L)
            action_outputs = batch['action_outputs'].cuda()  # (n, L)
            action_masks = batch['action_masks'].cuda()  # (n, L)
            # forward
            # - logprobs (n, L, #actions)
            # - output_feats (n, L, rnn_size)
            # - pred_feats (n, L, 3200) or None
            logprobs, _, pred_feats, _ = model(ego_feats, phrase_embs,
                                               action_inputs)
            nll_loss = nll_crit(logprobs, action_outputs, action_masks)
            mse_loss = 0
            if args.use_next:
                next_feats = batch['next_feats'].cuda()  # (n, L, 3200)
                mse_loss = mse_crit(pred_feats, next_feats, action_masks)
            if args.use_residual:
                # when both flags are set, the residual target wins
                res_feats = batch['res_feats'].cuda()  # (n, L, 3200)
                mse_loss = mse_crit(pred_feats, res_feats, action_masks)
            loss = nll_loss + args.mse_weight * mse_loss
            # backward
            loss.backward()
            model_utils.clip_gradient(optimizer, args.grad_clip)
            optimizer.step()

            # training log
            if iters % args.losses_log_every == 0:
                loss_history[iters] = loss.item()
                nll_loss_history[iters] = nll_loss.item()
                mse_loss_history[iters] = mse_loss.item() if (
                    args.use_next or args.use_residual) else 0
                print('iters[%s]epoch[%s], train_loss=%.3f (nll_loss=%.3f, mse_loss=%.3f) lr=%.2E, cur_seq_length=%s' % \
                  (iters, epoch, loss_history[iters], nll_loss_history[iters], mse_loss_history[iters], lr, train_loader.dataset.cur_seq_length))

            # decay learning rate: lr = base * 0.1 ** ((iters - start) / every)
            if args.learning_rate_decay_start > 0 and iters > args.learning_rate_decay_start:
                frac = (iters - args.learning_rate_decay_start
                        ) / args.learning_rate_decay_every
                decay_factor = 0.1**frac
                lr = args.learning_rate * decay_factor
                model_utils.set_lr(optimizer, lr)

            # evaluate
            if iters % args.save_checkpoint_every == 0:
                print('Checking validation ...')
                predictions, overall_nll, overall_teacher_forcing_acc, overall_mse, Nav_nll, Nav_teacher_forcing_acc = \
                  evaluate(val_dataset, model, nll_crit, mse_crit, opt)
                val_nll_history[iters] = overall_nll
                val_teacher_forcing_acc_history[
                    iters] = overall_teacher_forcing_acc
                val_mse_history[iters] = overall_mse
                val_nav_object_nll_history[iters] = Nav_nll['object']
                val_nav_object_teacher_forcing_acc_history[
                    iters] = Nav_teacher_forcing_acc['object']
                val_nav_room_nll_history[iters] = Nav_nll['room']
                val_nav_room_teacher_forcing_acc_history[
                    iters] = Nav_teacher_forcing_acc['room']

                # save model if best; "best" is judged only by the overall val
                # NLL (lower is better), perhaps a better weighting is needed.
                current_score = -overall_nll
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_val_acc = overall_teacher_forcing_acc
                    best_predictions = predictions
                    checkpoint_path = osp.join(args.checkpoint_dir,
                                               '%s.pth' % args.id)
                    checkpoint = {}
                    checkpoint['model_state'] = model.state_dict()
                    checkpoint['opt'] = vars(args)
                    torch.save(checkpoint, checkpoint_path)
                    print('model saved to %s.' % checkpoint_path)

                # write to json report
                infos.update({
                    'iters': iters,
                    'epoch': epoch,
                    'loss_history': loss_history,
                    'nll_loss_history': nll_loss_history,
                    'mse_loss_history': mse_loss_history,
                    'val_nll_history': val_nll_history,
                    'val_teacher_forcing_acc_history':
                    val_teacher_forcing_acc_history,
                    'val_mse_history': val_mse_history,
                    'val_nav_object_nll_history': val_nav_object_nll_history,
                    'val_nav_object_teacher_forcing_acc_history':
                    val_nav_object_teacher_forcing_acc_history,
                    'val_nav_room_nll_history': val_nav_room_nll_history,
                    'val_nav_room_teacher_forcing_acc_history':
                    val_nav_room_teacher_forcing_acc_history,
                    'best_val_score': best_val_score,
                    'best_val_acc': best_val_acc,
                    'best_predictions': predictions
                    if best_predictions is None else best_predictions,
                    'opt': vars(args),
                    'act_to_ix': train_dataset.act_to_ix,
                })
                infos_json = osp.join(args.checkpoint_dir, '%s.json' % args.id)
                with open(infos_json, 'w') as f:
                    json.dump(infos, f)
                print('infos saved to %s.' % infos_json)

            # update iters
            iters += 1

        # update epoch
        epoch += 1
Ejemplo n.º 2
0
def imitation(rank, args, shared_nav_model, counter):
    """Imitation-learning worker that updates ``shared_nav_model`` from expert paths.

    Presumably run as one of several worker processes. For every batch it:
      1. copies the shared weights into a local ``Navigator``;
      2. computes the teacher-forced NLL of the expert actions;
      3. back-propagates on the local copy;
      4. hands the gradients to the shared model via ``ensure_shared_grads``;
      5. steps an Adam optimizer that owns the shared parameters.

    Loops forever over the training set.

    Args:
        rank: worker index; selects the GPU, the seed offset and the
            tensorboard filename suffix.
        args: parsed command-line namespace. ``vars(args)`` is reused as the
            model option dict, so ``act_to_ix`` / ``num_actions`` are added to it.
        shared_nav_model: model whose parameters the optimizer updates.
        counter: shared object whose ``.value`` is used as the tensorboard step.
            This worker only reads it.
    """
    # set up tensorboard
    writer = tb.SummaryWriter(args.tb_dir, filename_suffix=str(rank))

    # set up cuda device for this worker
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # set up random seeds, offset by rank (np.random.seed, not np.random.randn,
    # seeds NumPy)
    random.seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    # set up loader
    train_loader_kwargs = {
        'data_json': args.imitation_data_json,
        'data_h5': args.imitation_data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': 'train',
        'max_seq_length': args.max_seq_length,
        'requires_imgs': False,
        'nav_types': args.nav_types,
        'question_types': ['all'],
    }
    train_dataset = NavImitationDataset(**train_loader_kwargs)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=1)
    print('train_loader set up.')

    # the optimizer owns the *shared* parameters; the local model below only
    # computes gradients
    lr = args.learning_rate
    optimizer = torch.optim.Adam(shared_nav_model.parameters(),
                                 lr=lr,
                                 betas=(args.optim_alpha, args.optim_beta),
                                 eps=args.optim_epsilon,
                                 weight_decay=args.weight_decay)

    # set up local model
    opt = vars(args)
    opt['act_to_ix'] = train_dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    model.cuda()
    print('navigator set up.')

    # set up criterions
    nll_crit = SeqModelCriterion().cuda()

    epoch = 0
    iters = 0

    # train
    while True:

        for batch in train_loader:
            # sync local model with the latest shared weights
            model.load_state_dict(shared_nav_model.state_dict())
            model.train()
            model.cuda()

            # batch = {qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats, next_feats, res_feats,
            #  action_inputs, action_outputs, action_masks, ego_imgs}
            ego_feats = batch['ego_feats'].cuda()  # (n, L, 3200)
            phrase_embs = batch['phrase_emb'].cuda()  # (n, 300)
            action_inputs = batch['action_inputs'].cuda()  # (n, L)
            action_outputs = batch['action_outputs'].cuda()  # (n, L)
            action_masks = batch['action_masks'].cuda()  # (n, L)
            # forward
            # - logprobs (n, L, #actions)
            logprobs, _, _, _ = model(ego_feats, phrase_embs, action_inputs)
            nll_loss = nll_crit(logprobs, action_outputs, action_masks)

            # backward. optimizer.zero_grad() only clears the shared model's
            # grads; nothing visible here clears the local model's grads, which
            # would otherwise accumulate across batches.
            optimizer.zero_grad()
            model.zero_grad()
            nll_loss.backward()
            clip_model_gradient(model.parameters(), args.grad_clip)
            ensure_shared_grads(model.cpu(), shared_nav_model)
            optimizer.step()

            if iters % 25 == 0:
                print('imitation-r%s(ep%s it%s lr%.2E loss%.4f)' %
                      (rank, epoch, iters, lr, nll_loss.item()))

            # write to tensorboard
            writer.add_scalar('imitation_rank/nll_loss',
                              float(nll_loss.item()), counter.value)

            # increase iters
            iters += 1

            # decay learning rate: lr = base * 0.1 ** ((iters - start) / every)
            if args.lr_decay > 0:
                if args.im_learning_rate_decay_start > 0 and iters > args.im_learning_rate_decay_start:
                    frac = (iters - args.im_learning_rate_decay_start
                            ) / args.im_learning_rate_decay_every
                    decay_factor = 0.1**frac
                    lr = args.learning_rate * decay_factor
                    model_utils.set_lr(optimizer, lr)

        epoch += 1
Ejemplo n.º 3
0
def train(rank, args, shared_nav_model, counter, lock):
    """REINFORCE worker that fine-tunes ``shared_nav_model`` by rolling out episodes.

    Each episode proceeds as follows:
      1. Sync a local ``Navigator`` from the shared weights.
      2. Spawn the agent. With probability ``args.shortest_path_ratio`` it is
         placed ``vlen`` steps before the end of the expert path, after warming
         up the recurrent state on the path prefix. Otherwise it is placed at
         ``dataset.spawn_agent(min_dist, max_dist)``.
      3. Sample up to ``args.max_episode_length`` actions.
      4. Add a terminal reward based on target IoU (objects) or room
         membership (rooms).
      5. Form discounted returns normalised by running reward statistics and
         apply a policy-gradient update to the shared model through
         ``ensure_shared_grads``.

    The spawn distance multiplier ``mult`` is adapted to the running reward.
    Loops forever, cycling through the dataset's environment sets.

    Args:
        rank: worker index. Selects the GPU, seed offset, data ``ratio`` range
            and log/tensorboard names.
        args: parsed command-line namespace. ``vars(args)`` is reused as the
            model option dict, so ``act_to_ix`` / ``num_actions`` are added to it.
        shared_nav_model: model whose parameters the optimizer updates.
        counter: shared episode counter. It is incremented under ``lock`` after
            every trained episode and used as the tensorboard step.
        lock: lock guarding ``counter``.
    """
    # set up tensorboard (one sub-directory per worker)
    writer = tb.SummaryWriter(osp.join(args.tb_dir, str(rank)))

    # set up cuda device
    gpu_id = args.gpus.index(args.gpus[rank % len(args.gpus)])
    torch.cuda.set_device(gpu_id)

    # set up random seeds, offset by rank (np.random.seed, not np.random.randn,
    # seeds NumPy -- spawning below relies on np.random)
    random.seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    # set up dataset; 'ratio' presumably selects this worker's share of the data
    cfg = {
        'colorFile':
        osp.join(args.house_meta_dir, 'colormap_fine.csv'),
        'roomTargetFile':
        osp.join(args.house_meta_dir, 'room_target_object_map.csv'),
        'modelCategoryFile':
        osp.join(args.house_meta_dir, 'ModelCategoryMapping.csv'),
        'prefix':
        args.house_data_dir,
    }
    loader_kwargs = {
        'data_json': args.data_json,
        'data_h5': args.data_h5,
        'path_feats_dir': args.path_feats_dir,
        'path_images_dir': args.path_images_dir,
        'split': 'train',
        'max_seq_length': args.max_seq_length,
        'nav_types': args.nav_types,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'cfg': cfg,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': 500,
        'pretrained_cnn_path': args.pretrained_cnn_path,
        'requires_imgs': False,
        'question_types': ['all'],
        'ratio': [rank / args.num_processes, (rank + 1) / args.num_processes]
    }
    dataset = NavReinforceDataset(**loader_kwargs)
    train_loader = DataLoader(dataset, batch_size=1, num_workers=0)
    print('train_loader set up.')

    # the optimizer owns the *shared* parameters; the local model only
    # computes gradients
    lr = args.learning_rate
    optimizer = torch.optim.Adam(shared_nav_model.parameters(),
                                 lr=lr,
                                 betas=(args.optim_alpha, args.optim_beta),
                                 eps=args.optim_epsilon,
                                 weight_decay=args.weight_decay)

    # set up local model
    opt = vars(args)
    opt['act_to_ix'] = dataset.act_to_ix
    opt['num_actions'] = len(opt['act_to_ix'])
    model = Navigator(opt)
    print('navigator[%s] set up.' % rank)

    # set up metrics outside epoch, as we use running rewards through whole training
    nav_metrics = NavMetric(
        info={
            'split': 'train',
            'thread': rank
        },
        metric_names=['reward', 'episode_length'],
        log_json=osp.join(args.log_dir, 'nav_train_' + str(rank) + '.json'),
    )
    nav_metrics.update([0, 100])
    reward_list, episode_length_list = [], []
    rwd_stats = Statistics()  # computing running mean and std of rewards

    # spawn-distance curriculum (object navigation vs. room navigation)
    is_object_nav = args.nav_types == ['object']
    min_dist = 5 if is_object_nav else 15  # 15 for rl3 and rl5, 10 for rl4
    max_dist = 35 if is_object_nav else 50
    mult = 0.1 if is_object_nav else 0.15  # path length multiplier
    rwd_thresh = 0.1 if is_object_nav else 0.00
    epoch = 0
    iters = 0

    # train
    while True:

        # reset envs
        train_loader.dataset._load_envs(start_idx=0, in_order=True)
        done = False  # current epoch is not done yet
        while not done:

            for batch in train_loader:
                # sync local model with the latest shared weights
                model.load_state_dict(shared_nav_model.state_dict())
                model.train()
                model.cuda()

                # load target_paths from available_idx (of some envs)
                # batch = {idx, qid, path_ix, house, id, type, phrase, phrase_emb, ego_feats,
                # action_inputs, action_outputs, action_masks, action_length}
                idx = batch['idx'][0].item()
                qid = batch['qid'][0].item()
                phrase_emb = batch['phrase_emb'].cuda()  # (1, 300) float
                raw_ego_feats = batch['ego_feats']  # (1, L, 3200) float
                action_inputs = batch['action_inputs']  # (1, L) int
                action_outputs = batch['action_outputs']  # (1, L) int
                action_length = batch['action_length'][0].item()
                tgt_type = batch['type'][0]
                tgt_id = batch['id'][0]
                tgt_phrase = batch['phrase'][0]

                # to be recorded
                episode_length = 0
                episode_done = True
                dists_to_target = []
                pos_queue = []
                actions = []
                rewards = []
                nav_log_probs = []

                # spawn agent
                h3d = train_loader.dataset.episode_house
                if np.random.uniform(0, 1, 1)[0] <= args.shortest_path_ratio:
                    # with probability shortest_path_ratio, spawn vlen steps
                    # before the end of the shortest path (only if vlen > 0,
                    # i.e., the shortest path is long enough)
                    use_shortest_path = True
                    vlen = min(max(min_dist, int(mult * action_length)),
                               action_length)
                    episode_pos_queue = train_loader.dataset.episode_pos_queue
                    # vlen == 0 would make [:-vlen] an empty prefix
                    if vlen > 0 and len(episode_pos_queue) > vlen:
                        # forward through navigator till spawn
                        prev_pos_queue = episode_pos_queue[:-vlen]  # till spawned position
                        ego_feats_pruned = raw_ego_feats[:, :len(
                            prev_pos_queue), :].cuda()  # (1, l, 3200)
                        action_inputs_pruned = action_inputs[:, :len(
                            prev_pos_queue)].cuda()  # (1, l)
                        _, _, _, state = model(
                            ego_feats_pruned, phrase_emb,
                            action_inputs_pruned)  # (1, l, rnn_size)
                        action = action_inputs[0, len(prev_pos_queue)].view(
                            -1).cuda()  # (1, )
                        init_pos = episode_pos_queue[-vlen]
                    else:
                        state = None
                        action = torch.LongTensor([
                            train_loader.dataset.act_to_ix['dummy']
                        ]).cuda()  # (1, )
                        init_pos = episode_pos_queue[0]  # use first position of the path
                else:
                    # otherwise randomly spawn agent
                    use_shortest_path = False
                    state = None
                    action = torch.LongTensor([
                        train_loader.dataset.act_to_ix['dummy']
                    ]).cuda()  # (1, )
                    init_pos, vlen = train_loader.dataset.spawn_agent(
                        min_dist, max_dist)
                    if init_pos is None:  # init_pos not found
                        continue

                # initiate
                h3d.env.reset(x=init_pos[0], y=init_pos[2], yaw=init_pos[3])
                init_dist_to_target = h3d.get_dist_to_target(h3d.env.cam.pos)
                if init_dist_to_target < 0:  # unreachable
                    continue
                dists_to_target += [init_dist_to_target]
                pos_queue += [init_pos]

                # act
                ego_img = h3d.env.render()
                ego_img = (
                    torch.from_numpy(ego_img.transpose(2, 0, 1)).float() /
                    255.).cuda()
                ego_feat = train_loader.dataset.cnn.extract_feats(
                    ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)
                prev_action, collision = None, False
                for step in range(args.max_episode_length):
                    # forward model one step
                    episode_length += 1
                    logprobs, state = model.forward_step(
                        ego_feat, phrase_emb, action,
                        state)  # (1, 4), (1, rnn_size)

                    # sample action
                    probs = torch.exp(logprobs)  # (1, 4)
                    action = probs.multinomial(
                        num_samples=1).detach()  # (1, 1)
                    if prev_action == 0 and collision and action[0][0].item(
                    ) == 0:
                        # special case: prev_action == "forward" && collision && cur_action == "forward"
                        # we sample from {'left', 'right', 'stop'} only
                        action = probs[0:1, 1:].multinomial(
                            num_samples=1).detach() + 1  # (1, 1)

                    if len(pos_queue) < min_dist:
                        # special case: our room navigator tends to stop early, let's push it to explore longer
                        action = probs[0:1, :3].multinomial(
                            num_samples=1).detach()  # (1, 1)

                    nav_log_probs.append(logprobs.gather(1, action))  # (1, 1)
                    action = action.view(-1)  # (1, )
                    actions.append(action)

                    # interact with environment
                    ego_img, reward, episode_done, collision = h3d.step(
                        action[0].item(), step_reward=True)
                    if not episode_done:
                        reward -= 0.01  # we don't wanna too long trajectory
                    episode_done = episode_done or episode_length >= args.max_episode_length  # no need actually
                    reward = max(min(reward, 1), -1)
                    rewards.append(reward)
                    prev_action = action[0].item()

                    # prepare state for next action
                    ego_img = (
                        torch.from_numpy(ego_img.transpose(2, 0, 1)).float() /
                        255.).cuda()
                    ego_feat = train_loader.dataset.cnn.extract_feats(
                        ego_img.unsqueeze(0), conv4_only=True)  # (1, 3200)

                    # add to result
                    dists_to_target.append(
                        h3d.get_dist_to_target(h3d.env.cam.pos))
                    pos_queue.append([
                        h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                        h3d.env.cam.pos.z, h3d.env.cam.yaw
                    ])

                    if episode_done:
                        break

                # final reward
                R = 0
                if tgt_type == 'object':
                    R = 1.0 if h3d.compute_target_iou(tgt_id) >= 0.1 else -1.0
                else:
                    R = 0.2 if h3d.is_inside_room(
                        pos_queue[-1],
                        train_loader.dataset.target_room) else -0.2
                    if R > 0 and h3d.compute_room_targets_iou(
                            room_to_objects[tgt_phrase]) > 0.1:
                        R += 1.0  # encourage agent to move to room-specific targets

                # discounted returns (gamma = 0.99) seeded with the final
                # reward; each return also feeds the running statistics
                new_rewards = []  # reshaped rewards, in time order
                for i in reversed(range(len(rewards))):
                    R = 0.99 * R + rewards[i]
                    new_rewards.insert(0, R)
                    rwd_stats.push(R)

                # policy-gradient loss with returns normalised by running mean/std
                nav_loss = 0
                advantages = []  # only their mean is logged, order is irrelevant
                for nav_log_prob, ret in zip(nav_log_probs, new_rewards):
                    advantage = (ret - rwd_stats.mean()) / (
                        rwd_stats.stddev() + 1e-5)  # rl2, rl3
                    nav_loss = nav_loss - nav_log_prob * advantage
                    advantages.append(advantage)
                nav_loss /= max(1, len(nav_log_probs))

                # backward. optimizer.zero_grad() only clears the shared model's
                # grads; nothing visible here clears the local model's grads,
                # which would otherwise accumulate across episodes.
                optimizer.zero_grad()
                model.zero_grad()
                nav_loss.backward()
                clip_model_gradient(model.parameters(), args.grad_clip)
                # Till this point, we have grads on model's parameters!
                # but shared_nav_model's grads is still not updated
                ensure_shared_grads(model.cpu(), shared_nav_model)
                optimizer.step()

                if iters % 5 == 0:
                    log_info = 'train-r%2s(ep%2s it%5s lr%.2E mult%.2f, run_rwd%6.3f, bs_rwd%6.3f, bs_std%2.2f), vlen(%s)=%2s, rwd=%6.3f, ep_len=%3s, tgt:%s' % \
                      (rank, epoch, iters, lr, mult, nav_metrics.metrics[0][1], rwd_stats.mean(), rwd_stats.stddev(), 'short' if use_shortest_path else 'spawn', vlen, np.mean(new_rewards), len(new_rewards), tgt_phrase)
                    print(log_info)

                # update metrics
                reward_list += new_rewards
                episode_length_list.append(episode_length)
                if len(episode_length_list) > 50:
                    nav_metrics.update([reward_list, episode_length_list])
                    nav_metrics.dump_log()
                    reward_list, episode_length_list = [], []

                # write to tensorboard
                writer.add_scalar('train_rank%s/nav_loss' % rank,
                                  nav_loss.item(), counter.value)
                writer.add_scalar('train_rank%s/steps_avg_rwd' % rank,
                                  float(np.mean(new_rewards)), counter.value)
                writer.add_scalar('train_rank%s/steps_sum_rwd' % rank,
                                  float(np.sum(new_rewards)), counter.value)
                writer.add_scalar('train_rank%s/advantage' % rank,
                                  float(np.mean(advantages)), counter.value)

                # increase counter as this episode ends
                with lock:
                    counter.value += 1

                # adapt mult to the running reward (baseline = nav_metrics.metrics[0][1])
                iters += 1
                if nav_metrics.metrics[0][1] > rwd_thresh:
                    mult = min(mult + 0.1, 1.0)
                    rwd_thresh += 0.01
                else:
                    mult = max(mult - 0.1, 0.1)
                    rwd_thresh -= 0.01
                rwd_thresh = max(0.1, min(rwd_thresh, 0.2))

                # decay learning rate: lr = base * 0.1 ** ((iters - start) / every)
                if args.learning_rate_decay_start > 0 and iters > args.learning_rate_decay_start:
                    frac = (iters - args.learning_rate_decay_start
                            ) / args.learning_rate_decay_every
                    decay_factor = 0.1**frac
                    lr = args.learning_rate * decay_factor
                    model_utils.set_lr(optimizer, lr)

            # next environments
            train_loader.dataset._load_envs(in_order=True)
            print("train_loader pruned_env_set len: {}".format(
                len(train_loader.dataset.pruned_env_set)))
            logging.info("train_loader pruned_env_set len: {}".format(
                len(train_loader.dataset.pruned_env_set)))
            if len(train_loader.dataset.pruned_env_set) == 0:
                done = True

        epoch += 1