Example #1
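evaluate_on_envs: runs a teacher-forcing validation pass over every environment in the dataloader, averages each metric over the total number of frames, and (for 'om' architectures) appends a parsing evaluation.
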
# NOTE: these snippets rely on project-level helpers defined elsewhere in the
# repo (FLAGS from absl-style flags, DictList, Dataloader, ModularPolicy,
# DropoutScheduler, compile.CompILE, teacherforce_batch, parsing_loop,
# evaluate_loop, logging_metrics, point_of_change, get_subtask_seq, f1).
# Standard-library and third-party imports used throughout:
import logging
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter  # or tensorboardX, depending on the repo


def evaluate_on_envs(bot, dataloader):
    val_metrics = {}
    bot.eval()
    envs = dataloader.env_names
    for env_name in envs:
        val_iter = dataloader.val_iter(batch_size=FLAGS.il_batch_size,
                                       env_names=[env_name],
                                       shuffle=True)
        output = DictList({})
        total_lengths = 0
        for batch, batch_lens, batch_sketch_lens in val_iter:
            if FLAGS.cuda:
                batch.apply(lambda _t: _t.cuda())
                batch_lens = batch_lens.cuda()
                batch_sketch_lens = batch_sketch_lens.cuda()

            # Teacher-forcing forward pass; no gradients needed at validation
            with torch.no_grad():
                start = time.time()
                batch_results, _ = bot.teacherforcing_batch(
                    batch,
                    batch_lens,
                    batch_sketch_lens,
                    recurrence=FLAGS.il_recurrence)
                end = time.time()
                logging.info('batch time: {:.3f}s'.format(end - start))
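            # Reduce each metric tensor to a per-batch scalar sum; the sums are
            # normalized by the total number of frames after the loop.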
            batch_results.apply(lambda _t: _t.sum().item())
            output.append(batch_results)
            total_lengths += batch_lens.sum().item()
            if FLAGS.debug:
                break
        output.apply(lambda _t: torch.tensor(_t).sum().item() / total_lengths)
        val_metrics[env_name] = dict(output.items())

    # Parsing
    if 'om' in FLAGS.arch:
        with torch.no_grad():
            parsing_stats, parsing_lines = parsing_loop(
                bot,
                dataloader=dataloader,
                batch_size=FLAGS.il_batch_size,
                cuda=FLAGS.cuda)
        for env_name in parsing_stats:
            parsing_stats[env_name].apply(lambda _t: np.mean(_t))
            val_metrics[env_name].update(parsing_stats[env_name])
        logging.info('Parsing results:')
        logging.info('\n' + '\n'.join(parsing_lines))

    # Evaluate on free-run environments (currently disabled)
    #if not FLAGS.debug:
    #    for sketch_length in val_metrics:
    #        envs = [gym.make('jacopinpad-v0', sketch_length=sketch_length,
    #                         max_steps_per_sketch=FLAGS.max_steps_per_sketch)
    #                for _ in range(FLAGS.eval_episodes)]
    #        with torch.no_grad():
    #            free_run_metric = batch_evaluate(envs=envs, bot=bot, cuda=FLAGS.cuda)
    #        val_metrics[sketch_length].update(free_run_metric)
    return val_metrics
Example #2
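main (TACO): training loop for a ModularPolicy trained by teacher forcing with a scheduled dropout rate; periodically evaluates on validation (and optional held-out test) sketch lengths and checkpoints the model with the best validation loss.
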
def main(training_folder):
    logging.info('Start TACO...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = ModularPolicy(nb_subtasks=10, input_dim=39,
                          n_actions=9,
                          a_mu=dataloader.a_mu,
                          a_std=dataloader.a_std,
                          s_mu=dataloader.s_mu,
                          s_std=dataloader.s_std)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.taco_lr)

    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.taco_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()

    # Held-out test sketch lengths (exclude any that overlap with training)
    test_sketch_lengths = set(FLAGS.test_sketch_lengths) - set(FLAGS.sketch_lengths)
    test_dataloader = None if len(test_sketch_lengths) == 0 else Dataloader(test_sketch_lengths, FLAGS.il_val_ratio)
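    # DropoutScheduler is a project helper; it exposes the current dropout_p
    # and is stepped once per training update below.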
    scheduler = DropoutScheduler()
    while True:
        if train_steps > FLAGS.taco_train_steps:
            logging.info('Reached maximum steps')
            break

        if train_steps % FLAGS.taco_eval_freq == 0:
            val_metrics = evaluate_loop(dataloader, model, dropout_p=scheduler.dropout_p)
            logging_metrics(nb_frames, train_steps, val_metrics, writer, 'val')

            if test_dataloader is not None:
                test_metrics = evaluate_loop(test_dataloader, model, dropout_p=scheduler.dropout_p)
                logging_metrics(nb_frames, train_steps, test_metrics, writer, 'test')

            avg_loss = [val_metrics[env_name].loss for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('New best validation loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_subtask_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_subtask_lengths = train_subtask_lengths.cuda()
        start = time.time()
        train_outputs = teacherforce_batch(modular_p=model,
                                           trajs=train_batch,
                                           lengths=train_lengths,
                                           subtask_lengths=train_subtask_lengths,
                                           decode=False,
                                           dropout_p=scheduler.dropout_p)
        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        scheduler.step()
        nb_frames += train_lengths.sum().item()
        end = time.time()
        fps = train_lengths.sum().item() / (end - start)
        train_outputs['fps'] = torch.tensor(fps)

        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.taco_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
Example #3
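main_loop (imitation learning): alternates teacher-forcing updates with gradient clipping, periodic checkpointing, and evaluation through evaluate_on_envs from Example #1.
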
def main_loop(bot, dataloader, opt, training_folder, test_dataloader=None):
    # Prepare
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.il_batch_size)
    nb_frames = 0
    train_stats = DictList()
    curr_best = np.inf
    while True:
        if train_steps > FLAGS.il_train_steps:
            logging.info('Reached maximum steps')
            break

        if train_steps % FLAGS.il_save_freq == 0:
            with open(
                    os.path.join(training_folder,
                                 'bot{}.pkl'.format(train_steps)), 'wb') as f:
                torch.save(bot, f)

        if train_steps % FLAGS.il_eval_freq == 0:
            # Evaluate on the validation split
            val_metrics = evaluate_on_envs(bot, dataloader)
            logging_metrics(nb_frames,
                            train_steps,
                            val_metrics,
                            writer,
                            prefix='val')

            # Evaluate on the held-out test environments
            if test_dataloader is not None:
                test_metrics = evaluate_on_envs(bot, test_dataloader)
                logging_metrics(nb_frames,
                                train_steps,
                                test_metrics,
                                writer,
                                prefix='test')

            avg_loss = [
                val_metrics[env_name]['loss'] for env_name in val_metrics
            ]
            avg_loss = np.mean(avg_loss)

            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('New best validation loss: {}'.format(avg_loss))

                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(bot, f)

        # Forward/Backward
        bot.train()
        train_batch, train_lengths, train_sketch_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lengths = train_sketch_lengths.cuda()

        start = time.time()
        train_batch_res, _ = bot.teacherforcing_batch(
            train_batch,
            train_lengths,
            train_sketch_lengths,
            recurrence=FLAGS.il_recurrence)
        train_batch_res.apply(lambda _t: _t.sum() / train_lengths.sum())
        batch_time = time.time() - start
        loss = train_batch_res.loss
        opt.zero_grad()
        loss.backward()
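        # Clip the global gradient norm at FLAGS.il_clip; clip_grad_norm_
        # returns the total norm before clipping, which is logged below.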
        params = [p for p in bot.parameters() if p.requires_grad]
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params,
                                                   max_norm=FLAGS.il_clip)
        opt.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()
        fps = train_lengths.sum().item() / batch_time

        stats = DictList()
        stats.grad_norm = grad_norm
        stats.loss = train_batch_res.loss.detach()
        stats.fps = torch.tensor(fps)
        train_stats.append(stats)

        if train_steps % FLAGS.il_eval_freq == 0:
            train_stats.apply(
                lambda _tensors: torch.stack(_tensors).mean().item())
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
Example #4
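main (CompILE): training loop for compile.CompILE; the validation pass additionally scores unsupervised segmentation with boundary F1 at tolerances 0-2 and per-frame subtask accuracy.
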
def main(training_folder):
    logging.info('Start CompILE...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = compile.CompILE(vec_size=39,
                            hidden_size=FLAGS.hidden_size,
                            action_size=9,
                            env_arch=FLAGS.env_arch,
                            max_num_segments=FLAGS.compile_max_segs,
                            latent_dist=FLAGS.compile_latent,
                            beta_b=FLAGS.compile_beta_b,
                            beta_z=FLAGS.compile_beta_z,
                            prior_rate=FLAGS.compile_prior_rate,
                            dataloader=dataloader)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.compile_lr)

    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.compile_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()
    while True:
        if train_steps > FLAGS.compile_train_steps:
            logging.info('Reached maximum steps')
            break

        if train_steps % FLAGS.compile_eval_freq == 0:
            # Validation
            val_metrics = {}
            model.eval()
            for env_name in FLAGS.sketch_lengths:
                val_metrics[env_name] = DictList()
                val_iter = dataloader.val_iter(
                    batch_size=FLAGS.compile_batch_size, env_names=[env_name])
                for val_batch, val_lengths, val_sketch_lens in val_iter:
                    if FLAGS.cuda:
                        val_batch.apply(lambda _t: _t.cuda())
                        val_lengths = val_lengths.cuda()
                        val_sketch_lens = val_sketch_lens.cuda()
                    with torch.no_grad():
                        val_outputs, extra_info = model(
                            val_batch, val_lengths, val_sketch_lens)
                    val_metrics[env_name].append(val_outputs)

                # Parsing: compare predicted segment boundaries against
                # ground-truth change points on one shuffled validation batch
                total_lengths = 0
                total_task_corrects = 0
                val_iter = dataloader.val_iter(batch_size=FLAGS.eval_episodes,
                                               env_names=[env_name],
                                               shuffle=True)
                val_batch, val_lengths, val_sketch_lens = next(val_iter)
                if FLAGS.cuda:
                    val_batch.apply(lambda _t: _t.cuda())
                    val_lengths = val_lengths.cuda()
                    val_sketch_lens = val_sketch_lens.cuda()
                with torch.no_grad():
                    val_outputs, extra_info = model(
                        val_batch, val_lengths, val_sketch_lens)
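                # Predicted boundary positions: stack the per-segment boundary
                # distributions from extra_info['segment'] and take the argmax
                # over time (layout assumed from the indexing below).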
                seg = torch.stack(extra_info['segment'], dim=1).argmax(-1)
                for batch_id, (length, sketch_length, _seg) in enumerate(
                        zip(val_lengths, val_sketch_lens, seg)):
                    traj = val_batch[batch_id]
                    traj = traj[:length]
                    _gt_subtask = traj.gt_onsets
                    target = point_of_change(_gt_subtask)
                    _seg = _seg.sort().values.cpu().tolist()

                    # Remove the last one because too trivial
                    val_metrics[env_name].append({
                        'f1_tol0': f1(target, _seg, 0),
                        'f1_tol1': f1(target, _seg, 1),
                        'f1_tol2': f1(target, _seg, 2)
                    })

                    # subtask
                    total_lengths += length.item()
                    _decoded_subtask = get_subtask_seq(
                        length.item(),
                        subtask=traj.tasks.tolist(),
                        use_ids=np.array(_seg))
                    total_task_corrects += (
                        _gt_subtask.cpu() == _decoded_subtask.cpu()).float().sum()

                # record task acc
                val_metrics[env_name].task_acc = total_task_corrects / total_lengths

                # Print the parsing result for the last sample in the batch
                lines = []
                lines.append('tru_ids: {}'.format(target))
                lines.append('dec_ids: {}'.format(_seg))
                logging.info('\n'.join(lines))
                val_metrics[env_name].apply(
                    lambda _t: torch.tensor(_t).float().mean().item())

            # Logger
            for env_name, metric in val_metrics.items():
                line = ['[VALID][{}] steps={}'.format(env_name, train_steps)]
                for k, v in metric.items():
                    line.append('{}: {:.4f}'.format(k, v))
                logging.info('\t'.join(line))

            mean_val_metric = DictList()
            for metric in val_metrics.values():
                mean_val_metric.append(metric)
            mean_val_metric.apply(lambda t: torch.mean(torch.tensor(t)))
            for k, v in mean_val_metric.items():
                writer.add_scalar('val/' + k, v.item(), nb_frames)
            writer.flush()

            avg_loss = [val_metrics[env_name].loss for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('New best validation loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_sketch_lens = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lens = train_sketch_lens.cuda()
        train_outputs, _ = model(train_batch, train_lengths, train_sketch_lens)

        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()

        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.compile_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()