Example #1
    def batch_iter(self,
                   trajs,
                   batch_size,
                   shuffle=True,
                   epochs=-1) -> DictList:
        """
        :param trajs: A list of DictList
        :param batch_size: int
        :param seq_len: int
        :param epochs: int. If -1, then forever
        :return: DictList [bsz, seq_len]
        """
        epoch_iter = range(1, epochs + 1) if epochs > 0 else _forever()
        for _ in epoch_iter:
            if shuffle:
                random.shuffle(trajs)

            start_idx = 0
            while start_idx < len(trajs):
                batch = DictList()
                lengths = []
                task_lengths = []
                for _traj in trajs[start_idx:start_idx + batch_size]:
                    lengths.append(len(_traj.actions))
                    task_lengths.append(len(_traj.tasks))
                    _traj.apply(lambda _t: torch.tensor(_t))
                    batch.append(_traj)

                batch.apply(lambda _t: pad_sequence(_t, batch_first=True))
                yield batch, torch.tensor(lengths), torch.tensor(task_lengths)
                start_idx += batch_size
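The pattern to note in batch_iter is that pad_sequence equalizes variable-length trajectories inside a batch while the raw lengths are kept for masking later. A minimal, self-contained sketch of that padding step, with made-up sequences:

import torch
from torch.nn.utils.rnn import pad_sequence

# Three hypothetical action sequences of different lengths.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
lengths = torch.tensor([len(s) for s in seqs])

# pad_sequence right-pads with zeros and stacks to [bsz, max_len].
batch = pad_sequence(seqs, batch_first=True)
print(batch)
# tensor([[1, 2, 3],
#         [4, 5, 0],
#         [6, 0, 0]])
print(lengths)  # tensor([3, 2, 1]), kept so padding can be masked out later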
Example #2
def batch_evaluate(envs: List, bot: ModelBot, cuda, verbose=False) -> DictList:
    """ Roll out the bot in all envs and return averaged episode metrics """
    obs = DictList()
    for env in envs:
        obs.append(DictList(env.reset()))
    obs.apply(lambda _t: torch.tensor(_t).float())
    actives = torch.arange(len(envs))
    if cuda:
        obs.apply(lambda _t: _t.cuda())
        actives = actives.cuda()

    trajs = [DictList() for _ in range(len(envs))]
    sketchs = obs.sketch.long()[0]  # all envs share the same sketch
    sketch_lengths = torch.tensor(sketchs.shape, device=sketchs.device)
    mems = bot.init_memory(
        sketchs.unsqueeze(0).repeat(len(actives), 1),
        sketch_lengths.repeat(len(actives))) if bot.is_recurrent else None

    # Continue roll out while at least one active
    steps = 0
    while len(actives) > 0:
        if verbose:
            print('active env:', len(actives))
        active_trajs = [trajs[i] for i in actives]
        with torch.no_grad():
            model_outputs = bot.get_action(
                obs.state,
                sketchs.unsqueeze(0).repeat(len(actives), 1),
                sketch_lengths.repeat(len(actives)), mems)
        actions = model_outputs.actions
        next_obs, rewards, dones = step_batch_envs(envs, actions, actives,
                                                   cuda)
        transition = DictList({'rewards': rewards})
        transition.update(obs)

        for idx, active_traj in enumerate(active_trajs):
            active_traj.append(transition[idx])
        steps += 1

        # Memory
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_outputs.mems

        # For next step
        un_done_ids = (~dones).nonzero().squeeze(-1)
        obs = next_obs[un_done_ids]
        actives = actives[un_done_ids]
        mems = next_mems[un_done_ids] if next_mems is not None else None

    metric = DictList()
    for traj, env in zip(trajs, envs):
        traj.apply(lambda _tensors: torch.stack(_tensors))
        metric.append({
            'ret': sum(env.local_score),
            'succs': traj.rewards.sum().item(),
            'length': len(traj.rewards)
        })
    metric.apply(lambda _t: np.mean(_t))
    return metric
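batch_evaluate keeps rollouts batched by shrinking the active set as episodes finish: (~dones).nonzero() selects the rows of still-running envs, and that index is applied uniformly to observations, env ids, and memory. The same pattern in isolation, with dummy tensors:

import torch

obs = torch.randn(4, 3)                            # 4 active envs, 3 features
actives = torch.arange(4)                          # original env ids
dones = torch.tensor([False, True, False, True])   # envs 1 and 3 just finished

un_done_ids = (~dones).nonzero().squeeze(-1)       # indices of live envs

obs = obs[un_done_ids]          # keep only live observations
actives = actives[un_done_ids]  # remember which original envs these map to
print(actives)                  # tensor([0, 2])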
Example #3
def step_batch_envs(envs, actions, actives, cuda):
    """ Step a batch of envs. And detect if there are inactive/done envs
    return obss, rewards, dones of the active envs
    """
    assert actions.shape[0] == len(actives)
    active_envs = [envs[i] for i in actives]
    obss = DictList()
    rewards = []
    dones = []
    for action, env in zip(actions, active_envs):
        obs, reward, done, _ = env.step(action.cpu().numpy())
        obss.append(obs)
        rewards.append(reward)
        dones.append(done)

    obss.apply(lambda _t: torch.tensor(_t).float())
    rewards = torch.tensor(rewards).float()
    dones = torch.tensor(dones)

    if cuda:
        obss.apply(lambda _t: _t.cuda())
        rewards = rewards.cuda()
        dones = dones.cuda()

    # Note: the caller updates the active set based on `dones`
    return obss, rewards, dones
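step_batch_envs bridges torch and the gym-style env API: each action tensor is moved to CPU and converted to numpy before env.step, and the per-env Python scalars coming back are re-batched into tensors. A runnable sketch with a toy stand-in env (ToyEnv is invented for illustration):

import torch

class ToyEnv:
    """ Stand-in for a gym-style env whose step() takes a numpy action. """
    def step(self, action):
        reward = float(action.sum())
        return {'state': action * 2}, reward, reward > 1.0, {}

envs = [ToyEnv(), ToyEnv()]
actions = torch.tensor([[0.2, 0.3], [0.9, 0.4]])

rewards, dones = [], []
for action, env in zip(actions, envs):
    _, reward, done, _ = env.step(action.cpu().numpy())
    rewards.append(reward)
    dones.append(done)

print(torch.tensor(rewards).float())  # tensor([0.5000, 1.3000])
print(torch.tensor(dones))            # tensor([False,  True])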
Example #4
def evaluate_on_envs(bot, dataloader):
    val_metrics = {}
    bot.eval()
    envs = dataloader.env_names
    for env_name in envs:
        val_iter = dataloader.val_iter(batch_size=FLAGS.il_batch_size,
                                       env_names=[env_name],
                                       shuffle=True)
        output = DictList({})
        total_lengths = 0
        for batch, batch_lens, batch_sketch_lens in val_iter:
            if FLAGS.cuda:
                batch.apply(lambda _t: _t.cuda())
                batch_lens = batch_lens.cuda()
                batch_sketch_lens = batch_sketch_lens.cuda()

            # Initialize memory
            with torch.no_grad():
                #batch_results = run_batch(batch, batch_lens, batch_sketch_lens, bot, mode='val')
                start = time.time()
                batch_results, _ = bot.teacherforcing_batch(
                    batch,
                    batch_lens,
                    batch_sketch_lens,
                    recurrence=FLAGS.il_recurrence)
                end = time.time()
                print('batch time', end - start)
            batch_results.apply(lambda _t: _t.sum().item())
            output.append(batch_results)
            total_lengths += batch_lens.sum().item()
            if FLAGS.debug:
                break
        output.apply(lambda _t: torch.tensor(_t).sum().item() / total_lengths)
        val_metrics[env_name] = {k: v for k, v in output.items()}

    # Parsing
    if 'om' in FLAGS.arch:
        with torch.no_grad():
            parsing_stats, parsing_lines = parsing_loop(
                bot,
                dataloader=dataloader,
                batch_size=FLAGS.il_batch_size,
                cuda=FLAGS.cuda)
        for env_name in parsing_stats:
            parsing_stats[env_name].apply(lambda _t: np.mean(_t))
            val_metrics[env_name].update(parsing_stats[env_name])
        logging.info('Parsing results:')
        logging.info('\n' + '\n'.join(parsing_lines))

    # evaluate on free run env
    #if not FLAGS.debug:
    #    for sketch_length in val_metrics:
    #        envs = [gym.make('jacopinpad-v0', sketch_length=sketch_length,
    #                         max_steps_per_sketch=FLAGS.max_steps_per_sketch)
    #                for _ in range(FLAGS.eval_episodes)]
    #        with torch.no_grad():
    #            free_run_metric = batch_evaluate(envs=envs, bot=bot, cuda=FLAGS.cuda)
    #        val_metrics[sketch_length].update(free_run_metric)
    return val_metrics
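A detail worth noticing in evaluate_on_envs is the normalization: per-batch sums are accumulated and divided once by total_lengths, so every timestep counts equally no matter how the batches were sized. A tiny sketch of that aggregation, with invented numbers:

import torch

batch_loss_sums = [12.0, 30.0, 7.5]          # summed loss per val batch
batch_lens = [torch.tensor([3, 5]), torch.tensor([10, 2]), torch.tensor([3])]

total_loss = sum(batch_loss_sums)
total_lengths = sum(l.sum().item() for l in batch_lens)
print(total_loss / total_lengths)  # per-timestep loss, not a per-batch mean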
Example #5
    def teacherforcing_batch(self, batch: DictList, batch_lengths,
                             sketch_lengths,
                             recurrence) -> (DictList, DictList):
        """
        :param batch: DictList object [bsz, seqlen]
        :param batch_lengths: [bsz]
        :param sketch_lengths: [bsz]
        :param recurrence: an int
        :return:
            stats: A DictList of bsz, mem_size
            extra_info: A DictList of extra info
        """
        bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
        sketchs = batch.tasks
        final_outputs = DictList({})
        extra_info = DictList({})
        mems = None
        if self.is_recurrent:
            mems = self.init_memory(sketchs, sketch_lengths)

        for t in range(seqlen):
            final_output = DictList({})
            model_output = self.forward(batch.states[:, t], sketchs,
                                        sketch_lengths, mems)
            logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
            if 'log_end' in model_output:
                # log(p_end + (1 - p_end) * p_action), combined in log space
                log_no_end_term = model_output.log_no_end + logprobs
                logprobs = torch.logsumexp(torch.stack(
                    [model_output.log_end, log_no_end_term], dim=-1),
                                           dim=-1)
                final_output.log_end = model_output.log_end
            final_output.logprobs = logprobs
            if 'p' in model_output:
                extra_info.append({'p': model_output.p})
            final_outputs.append(final_output)

            # Update memory
            next_mems = None
            if self.is_recurrent:
                next_mems = model_output.mems
                if (t + 1) % recurrence == 0:
                    next_mems = next_mems.detach()
            mems = next_mems

        # Stack on time dim
        final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
        extra_info.apply(lambda _tensors: torch.stack(_tensors, dim=1))
        sequence_mask = torch.arange(
            batch_lengths.max().item(),
            device=batch_lengths.device)[None, :] < batch_lengths[:, None]
        final_outputs.loss = -final_outputs.logprobs
        if 'log_end' in final_outputs:
            batch_ids = torch.arange(bsz, device=batch.states.device)
            final_outputs.loss[batch_ids, batch_lengths - 1] = \
                final_outputs.log_end[batch_ids, batch_lengths - 1]
        final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
        return final_outputs, extra_info
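The log_end branch evaluates log(p_end + (1 - p_end) * p_action) without leaving log space: log_no_end + logprobs is the log of the second term, and logsumexp adds the two probabilities stably. A quick numerical check of that identity:

import torch

p_end, p_action = torch.tensor(0.1), torch.tensor(0.7)

log_end = p_end.log()
log_no_end = (1 - p_end).log()
logprob = p_action.log()

# Stable log-space mixture versus the direct probability computation.
mixed = torch.logsumexp(torch.stack([log_end, log_no_end + logprob]), dim=0)
direct = (p_end + (1 - p_end) * p_action).log()
print(torch.allclose(mixed, direct))  # True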
Example #6
def run_batch(batch: DictList,
              batch_lengths,
              sketch_lengths,
              bot: ModelBot,
              mode='train') \
        -> (DictList, torch.Tensor):
    """
    :param batch: DictList object [bsz, seqlen]
    :param bot:  A model Bot
    :param mode: 'train' or 'eval'
    :return:
        stats: A DictList of bsz, mem_size
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    mems = None
    if bot.is_recurrent:
        mems = bot.init_memory(sketchs, sketch_lengths)

    for t in range(seqlen):
        final_output = DictList({})
        model_output = bot.forward(batch.states[:, t], sketchs, sketch_lengths,
                                   mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # log(p_end + (1 - p_end) * p_action), combined in log space
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1),
                                       dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        final_outputs.append(final_output)

        # Update memory
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % FLAGS.il_recurrence == 0 and mode == 'train':
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack on time dim
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = \
            final_outputs.log_end[batch_ids, batch_lengths - 1]
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs
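Both teacher-forcing routines silence the loss on padding with a broadcasted comparison: arange(max_len)[None, :] < lengths[:, None] produces a [bsz, max_len] boolean mask. The mask construction in isolation:

import torch

lengths = torch.tensor([3, 1, 2])
max_len = lengths.max().item()

# True exactly where t < length, i.e. on real (non-padding) timesteps.
mask = torch.arange(max_len)[None, :] < lengths[:, None]
print(mask)
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])

loss = torch.ones(3, max_len)
print(loss.masked_fill(~mask, 0.))  # padded positions contribute nothing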
Example #7
def evaluate_on_env(modular_p: ModularPolicy,
                    sketch_length,
                    max_steps_per_sketch,
                    use_sketch_id=False):
    start = time.time()
    env = gym.make('jacopinpad-v0',
                   sketch_length=sketch_length,
                   max_steps_per_sketch=max_steps_per_sketch)
    device = next(modular_p.parameters()).device
    modular_p.eval()
    obs = DictList(env.reset())
    modular_p.reset(subtasks=obs.sketch)
    obs.apply(lambda _t: torch.tensor(_t, device=device).float())
    done = False
    traj = DictList()
    try:
        while not done:
            if not use_sketch_id:
                action = modular_p.get_action(obs.state.unsqueeze(0))
            else:
                action = modular_p.get_action(
                    obs.state.unsqueeze(0),
                    sketch_idx=int(obs.sketch_idx.item()))
            if action is not None:
                next_obs, reward, done, _ = env.step(action.cpu().numpy()[0])
                transition = {
                    'reward': reward,
                    'action': action,
                    'features': obs.state
                }
                traj.append(transition)

                obs = DictList(next_obs)
                obs.apply(lambda _t: torch.tensor(_t, device=device).float())
            else:
                done = True
    except MujocoException:
        pass
    end = time.time()
    if 'reward' in traj:
        return {
            'succs': np.sum(traj.reward),
            'episode_length': len(traj.reward),
            'ret': sum(env.local_score),
            'runtime': end - start
        }
    else:
        return {
            'succs': 0,
            'episode_length': 0,
            'ret': 0,
            'runtime': end - start
        }
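evaluate_on_env reads the model's device from its first parameter rather than threading a device argument around; the pattern works for any nn.Module that has parameters. A minimal demonstration with a plain Linear as a stand-in for ModularPolicy:

import torch
from torch import nn

model = nn.Linear(4, 2)            # stand-in for ModularPolicy
device = next(model.parameters()).device

# Tensors built from env outputs land on the same device as the model.
obs = torch.tensor([0.1, 0.2, 0.3, 0.4], device=device).float()
print(obs.device)                  # cpu, or cuda:0 if the model was moved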
Example #8
def teacherforce_batch(modular_p: ModularPolicy,
                       trajs: DictList,
                       lengths,
                       subtask_lengths,
                       dropout_p,
                       decode=False):
    """ Return log probs of a trajectory """
    dropout_p = 0. if decode else dropout_p
    unique_tasks = set()
    for subtask in trajs.tasks:
        for task_id in subtask:
            unique_tasks.add(task_id.item())
    unique_tasks = list(unique_tasks)

    # Forward for all unique task
    # task_results [bsz, length, all_tasks]
    states = trajs.states.float()
    targets = trajs.actions.float()
    all_task_results = DictList()
    for task in unique_tasks:
        all_task_results.append(
            modular_p.forward(task, states, targets, dropout_p=dropout_p))
    all_task_results.apply(lambda _t: torch.stack(_t, dim=2))

    # pad subtasks
    subtasks = trajs.tasks

    # results [bsz, len, nb_tasks]
    results = DictList()
    for batch_id, subtask in enumerate(subtasks):
        curr_result = DictList()
        for task in subtask:
            task_id = unique_tasks.index(task.item())
            curr_result.append(all_task_results[batch_id, :, task_id])

        # [len, tasks]
        curr_result.apply(lambda _t: torch.stack(_t, dim=1))
        results.append(curr_result)
    results.apply(lambda _t: torch.stack(_t, dim=0))

    # Training
    if not decode:
        log_alphas = tac_forward_log(action_logprobs=results.action_logprobs,
                                     stop_logprobs=results.stop_logprobs,
                                     lengths=lengths,
                                     subtask_lengths=subtask_lengths)
        seq_logprobs = log_alphas[
            torch.arange(log_alphas.shape[0], device=log_alphas.device),
            lengths - 1, subtask_lengths - 1]
        avg_logprobs = seq_logprobs.sum() / lengths.sum()
        return {'loss': -avg_logprobs}

    # Decode
    else:
        alphas, _ = tac_forward(action_logprobs=results.action_logprobs,
                                stop_logprobs=results.stop_logprobs,
                                lengths=lengths,
                                subtask_lengths=subtask_lengths)
        decoded = alphas.argmax(-1)
        batch_ids = torch.arange(decoded.shape[0],
                                 device=decoded.device).unsqueeze(-1).repeat(
                                     1, decoded.shape[1])
        decoded_subtasks = subtasks[batch_ids, decoded]
        total_task_corrects = 0
        for idx, (subtask, decoded_subtask, action, length, gt) in enumerate(
                zip(subtasks, decoded_subtasks, trajs.actions, lengths,
                    trajs.gt_onsets)):
            _decoded_subtask = decoded_subtask[:length]
            _action = action[:length]
            gt = gt[:length]
            total_task_corrects += (gt == _decoded_subtask).float().sum()
        return {
            'task_acc': total_task_corrects / lengths.sum()
        }, {
            'tru': gt,
            'act': _action,
            'dec': _decoded_subtask
        }
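The decode path maps each timestep's argmax slot back to a concrete task id with paired advanced indexing: subtasks[batch_ids, decoded] selects one entry per (batch, time) position. A self-contained version of that gather, with invented shapes:

import torch

subtasks = torch.tensor([[7, 8, 9],      # per-example subtask sequence
                         [4, 5, 6]])
decoded = torch.tensor([[0, 0, 2, 2],    # argmax slot per timestep
                        [1, 2, 2, 0]])

batch_ids = torch.arange(decoded.shape[0]).unsqueeze(-1).repeat(
    1, decoded.shape[1])
decoded_subtasks = subtasks[batch_ids, decoded]
print(decoded_subtasks)
# tensor([[7, 7, 9, 9],
#         [5, 6, 6, 4]])

torch.gather(subtasks, 1, decoded) would compute the same thing in one call.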
Example #9
def main(training_folder):
    logging.info('start taco...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = ModularPolicy(nb_subtasks=10, input_dim=39,
                          n_actions=9,
                          a_mu=dataloader.a_mu,
                          a_std=dataloader.a_std,
                          s_mu=dataloader.s_mu,
                          s_std=dataloader.s_std)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.taco_lr)

    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.taco_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()

    # test dataloader
    test_sketch_lengths = set(FLAGS.test_sketch_lengths) - set(FLAGS.sketch_lengths)
    test_dataloader = None if len(test_sketch_lengths) == 0 else Dataloader(test_sketch_lengths, FLAGS.il_val_ratio)
    scheduler = DropoutScheduler()
    while True:
        if train_steps > FLAGS.taco_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.taco_eval_freq == 0:
            val_metrics = evaluate_loop(dataloader, model, dropout_p=scheduler.dropout_p)
            logging_metrics(nb_frames, train_steps, val_metrics, writer, 'val')

            if test_dataloader is not None:
                test_metrics = evaluate_loop(test_dataloader, model, dropout_p=scheduler.dropout_p)
                logging_metrics(nb_frames, train_steps, test_metrics, writer, 'test')

            avg_loss = [val_metrics[env_name].loss for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_subtask_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_subtask_lengths = train_subtask_lengths.cuda()
        start = time.time()
        train_outputs = teacherforce_batch(modular_p=model,
                                           trajs=train_batch,
                                           lengths=train_lengths,
                                           subtask_lengths=train_subtask_lengths,
                                           decode=False,
                                           dropout_p=scheduler.dropout_p)
        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        scheduler.step()
        nb_frames += train_lengths.sum().item()
        end = time.time()
        fps = train_lengths.sum().item() / (end - start)
        train_outputs['fps'] = torch.tensor(fps)

        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.taco_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
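The loop's checkpointing is the usual keep-the-best pattern: track the lowest validation loss seen so far and overwrite bot_best.pkl whenever it improves. A stripped-down sketch, with a dummy model, a hypothetical output directory, and made-up losses:

import os
import torch
from torch import nn

model = nn.Linear(4, 2)                  # stand-in for ModularPolicy
training_folder = '/tmp/demo_run'        # hypothetical output directory
os.makedirs(training_folder, exist_ok=True)
curr_best = float('inf')

for avg_loss in [0.9, 0.7, 0.8, 0.5]:    # pretend validation losses
    if avg_loss < curr_best:
        curr_best = avg_loss
        with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
            torch.save(model, f)         # pickles the whole module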
Example #10
def main_loop(bot, dataloader, opt, training_folder, test_dataloader=None):
    # Prepare
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.il_batch_size)
    nb_frames = 0
    train_stats = DictList()
    curr_best = np.inf
    while True:
        if train_steps > FLAGS.il_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.il_save_freq == 0:
            with open(
                    os.path.join(training_folder,
                                 'bot{}.pkl'.format(train_steps)), 'wb') as f:
                torch.save(bot, f)

        if train_steps % FLAGS.il_eval_freq == 0:
            # testing on valid
            val_metrics = evaluate_on_envs(bot, dataloader)
            logging_metrics(nb_frames,
                            train_steps,
                            val_metrics,
                            writer,
                            prefix='val')

            # testing on test env
            if test_dataloader is not None:
                test_metrics = evaluate_on_envs(bot, test_dataloader)
                logging_metrics(nb_frames,
                                train_steps,
                                test_metrics,
                                writer,
                                prefix='test')

            avg_loss = [
                val_metrics[env_name]['loss'] for env_name in val_metrics
            ]
            avg_loss = np.mean(avg_loss)

            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))

                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(bot, f)

        # Forward/Backward
        bot.train()
        train_batch, train_lengths, train_sketch_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lengths = train_sketch_lengths.cuda()

        start = time.time()
        #train_batch_res = run_batch(train_batch, train_lengths, train_sketch_lengths, bot)
        train_batch_res, _ = bot.teacherforcing_batch(
            train_batch,
            train_lengths,
            train_sketch_lengths,
            recurrence=FLAGS.il_recurrence)
        train_batch_res.apply(lambda _t: _t.sum() / train_lengths.sum())
        batch_time = time.time() - start
        loss = train_batch_res.loss
        opt.zero_grad()
        loss.backward()
        params = [p for p in bot.parameters() if p.requires_grad]
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params,
                                                   max_norm=FLAGS.il_clip)
        opt.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()
        fps = train_lengths.sum().item() / batch_time

        stats = DictList()
        stats.grad_norm = grad_norm
        stats.loss = train_batch_res.loss.detach()
        stats.fps = torch.tensor(fps)
        train_stats.append(stats)

        if train_steps % FLAGS.il_eval_freq == 0:
            train_stats.apply(
                lambda _tensors: torch.stack(_tensors).mean().item())
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
def main(training_folder):
    logging.info('Start compile...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = compile.CompILE(vec_size=39,
                            hidden_size=FLAGS.hidden_size,
                            action_size=9,
                            env_arch=FLAGS.env_arch,
                            max_num_segments=FLAGS.compile_max_segs,
                            latent_dist=FLAGS.compile_latent,
                            beta_b=FLAGS.compile_beta_b,
                            beta_z=FLAGS.compile_beta_z,
                            prior_rate=FLAGS.compile_prior_rate,
                            dataloader=dataloader)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.compile_lr)

    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.compile_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()
    while True:
        if train_steps > FLAGS.compile_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.compile_eval_freq == 0:
            # Testing
            val_metrics = {}
            model.eval()
            for env_name in FLAGS.sketch_lengths:
                val_metrics[env_name] = DictList()
                val_iter = dataloader.val_iter(
                    batch_size=FLAGS.compile_batch_size, env_names=[env_name])
                for val_batch, val_lengths, val_sketch_lens in val_iter:
                    if FLAGS.cuda:
                        val_batch.apply(lambda _t: _t.cuda())
                        val_lengths = val_lengths.cuda()
                        val_sketch_lens = val_sketch_lens.cuda()
                    with torch.no_grad():
                        val_outputs, extra_info = model.forward(
                            val_batch, val_lengths, val_sketch_lens)
                    val_metrics[env_name].append(val_outputs)

                # Parsing
                total_lengths = 0
                total_task_corrects = 0
                val_iter = dataloader.val_iter(batch_size=FLAGS.eval_episodes,
                                               env_names=[env_name],
                                               shuffle=True)
                val_batch, val_lengths, val_sketch_lens = next(val_iter)
                if FLAGS.cuda:
                    val_batch.apply(lambda _t: _t.cuda())
                    val_lengths = val_lengths.cuda()
                    val_sketch_lens = val_sketch_lens.cuda()
                with torch.no_grad():
                    val_outputs, extra_info = model.forward(
                        val_batch, val_lengths, val_sketch_lens)
                seg = torch.stack(extra_info['segment'], dim=1).argmax(-1)
                for batch_id, (length, sketch_length, _seg) in enumerate(
                        zip(val_lengths, val_sketch_lens, seg)):
                    traj = val_batch[batch_id]
                    traj = traj[:length]
                    _gt_subtask = traj.gt_onsets
                    target = point_of_change(_gt_subtask)
                    _seg = _seg[_seg.sort()[1]].cpu().tolist()

                    # Remove the last one because too trivial
                    val_metrics[env_name].append({
                        'f1_tol0': f1(target, _seg, 0),
                        'f1_tol1': f1(target, _seg, 1),
                        'f1_tol2': f1(target, _seg, 2)
                    })

                    # subtask
                    total_lengths += length.item()
                    _decoded_subtask = get_subtask_seq(
                        length.item(),
                        subtask=traj.tasks.tolist(),
                        use_ids=np.array(_seg))
                    total_task_corrects += (
                        _gt_subtask.cpu() == _decoded_subtask.cpu()).float().sum()

                # record task acc
                val_metrics[env_name].task_acc = total_task_corrects / total_lengths

                # Print parsing result
                lines = []
                lines.append('tru_ids: {}'.format(target))
                lines.append('dec_ids: {}'.format(_seg))
                logging.info('\n'.join(lines))
                val_metrics[env_name].apply(
                    lambda _t: torch.tensor(_t).float().mean().item())

            # Logger
            for env_name, metric in val_metrics.items():
                line = ['[VALID][{}] steps={}'.format(env_name, train_steps)]
                for k, v in metric.items():
                    line.append('{}: {:.4f}'.format(k, v))
                logging.info('\t'.join(line))

            mean_val_metric = DictList()
            for metric in val_metrics.values():
                mean_val_metric.append(metric)
            mean_val_metric.apply(lambda t: torch.mean(torch.tensor(t)))
            for k, v in mean_val_metric.items():
                writer.add_scalar('val/' + k, v.item(), nb_frames)
            writer.flush()

            avg_loss = [val_metrics[env_name].loss for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(model, f)

        model.train()
    train_batch, train_lengths, train_sketch_lens = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lens = train_sketch_lens.cuda()
        train_outputs, _ = model.forward(train_batch, train_lengths,
                                         train_sketch_lens)

        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()

        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.compile_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
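main_loop clips the global gradient norm before each optimizer step: clip_grad_norm_ rescales all gradients in place when their combined norm exceeds max_norm, and returns the pre-clipping norm, which the loop logs as grad_norm. A minimal standalone use:

import torch
from torch import nn

model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 8)).pow(2).mean()
opt.zero_grad()
loss.backward()

params = [p for p in model.parameters() if p.requires_grad]
grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params, max_norm=1.0)
print(float(grad_norm))   # norm before clipping
opt.step()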