def evaluate_on_envs(bot, dataloader):
    val_metrics = {}
    bot.eval()
    envs = dataloader.env_names
    for env_name in envs:
        val_iter = dataloader.val_iter(batch_size=FLAGS.il_batch_size,
                                       env_names=[env_name],
                                       shuffle=True)
        output = DictList({})
        total_lengths = 0
        for batch, batch_lens, batch_sketch_lens in val_iter:
            if FLAGS.cuda:
                batch.apply(lambda _t: _t.cuda())
                batch_lens = batch_lens.cuda()
                batch_sketch_lens = batch_sketch_lens.cuda()

            # Initialize memory
            with torch.no_grad():
                #batch_results = run_batch(batch, batch_lens, batch_sketch_lens, bot, mode='val')
                start = time.time()
                batch_results, _ = bot.teacherforcing_batch(
                    batch, batch_lens, batch_sketch_lens,
                    recurrence=FLAGS.il_recurrence)
                end = time.time()
                print('batch time', end - start)

            batch_results.apply(lambda _t: _t.sum().item())
            output.append(batch_results)
            total_lengths += batch_lens.sum().item()
            if FLAGS.debug:
                break

        output.apply(lambda _t: torch.tensor(_t).sum().item() / total_lengths)
        val_metrics[env_name] = {k: v for k, v in output.items()}

    # Parsing
    if 'om' in FLAGS.arch:
        with torch.no_grad():
            parsing_stats, parsing_lines = parsing_loop(
                bot, dataloader=dataloader,
                batch_size=FLAGS.il_batch_size,
                cuda=FLAGS.cuda)
        for env_name in parsing_stats:
            parsing_stats[env_name].apply(lambda _t: np.mean(_t))
            val_metrics[env_name].update(parsing_stats[env_name])
        logging.info('Get parsing result')
        logging.info('\n' + '\n'.join(parsing_lines))

    # evaluate on free run env
    #if not FLAGS.debug:
    #    for sketch_length in val_metrics:
    #        envs = [gym.make('jacopinpad-v0', sketch_length=sketch_length,
    #                         max_steps_per_sketch=FLAGS.max_steps_per_sketch)
    #                for _ in range(FLAGS.eval_episodes)]
    #        with torch.no_grad():
    #            free_run_metric = batch_evaluate(envs=envs, bot=bot, cuda=FLAGS.cuda)
    #        val_metrics[sketch_length].update(free_run_metric)
    return val_metrics
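# The code in this file relies on a project-level `DictList` container with
# attribute access, a per-key `append`, and an in-place `apply`. Its real
# implementation is not part of this excerpt; the class below is only a
# minimal, illustrative sketch of the interface assumed by evaluate_on_envs
# and the training loops, not the repo's actual DictList.
class DictListSketch(dict):
    """Hypothetical stand-in: a dict whose keys are also attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value

    def append(self, other):
        # Accumulate another dict's values key by key, building per-key lists.
        for k, v in other.items():
            self.setdefault(k, []).append(v)

    def apply(self, fn):
        # Replace every stored value with fn(value), in place.
        for k in self:
            self[k] = fn(self[k])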
def main(training_folder):
    logging.info('start taco...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = ModularPolicy(nb_subtasks=10, input_dim=39, n_actions=9,
                          a_mu=dataloader.a_mu, a_std=dataloader.a_std,
                          s_mu=dataloader.s_mu, s_std=dataloader.s_std)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.taco_lr)
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.taco_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()

    # test dataloader
    test_sketch_lengths = set(FLAGS.test_sketch_lengths) - set(FLAGS.sketch_lengths)
    test_dataloader = None if len(test_sketch_lengths) == 0 else \
        Dataloader(test_sketch_lengths, FLAGS.il_val_ratio)
    scheduler = DropoutScheduler()

    while True:
        if train_steps > FLAGS.taco_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.taco_eval_freq == 0:
            val_metrics = evaluate_loop(dataloader, model, dropout_p=scheduler.dropout_p)
            logging_metrics(nb_frames, train_steps, val_metrics, writer, 'val')
            if test_dataloader is not None:
                test_metrics = evaluate_loop(test_dataloader, model, dropout_p=scheduler.dropout_p)
                logging_metrics(nb_frames, train_steps, test_metrics, writer, 'test')

            avg_loss = [val_metrics[env_name].loss for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_subtask_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_subtask_lengths = train_subtask_lengths.cuda()

        start = time.time()
        train_outputs = teacherforce_batch(modular_p=model,
                                           trajs=train_batch,
                                           lengths=train_lengths,
                                           subtask_lengths=train_subtask_lengths,
                                           decode=False,
                                           dropout_p=scheduler.dropout_p)
        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        scheduler.step()
        nb_frames += train_lengths.sum().item()
        end = time.time()
        fps = train_lengths.sum().item() / (end - start)
        train_outputs['fps'] = torch.tensor(fps)
        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.taco_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
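# `evaluate_loop` and `logging_metrics` above are project helpers that are not
# shown in this excerpt. The function below is only a plausible sketch of
# `logging_metrics`, inferred from its call sites (frame count, step count,
# per-env metric dicts, a SummaryWriter, and a 'val'/'test' prefix); it assumes
# the module-level `logging` import used elsewhere in this file.
def logging_metrics_sketch(nb_frames, train_steps, metrics, writer, prefix):
    for env_name, metric in metrics.items():
        line = ['[{}][{}] steps={}'.format(prefix.upper(), env_name, train_steps)]
        for k, v in metric.items():
            line.append('{}: {:.4f}'.format(k, float(v)))
            # One scalar per env/metric pair, indexed by frames seen so far.
            writer.add_scalar('{}/{}/{}'.format(prefix, env_name, k),
                              float(v), global_step=nb_frames)
        logging.info('\t'.join(line))
    writer.flush()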
def main_loop(bot, dataloader, opt, training_folder, test_dataloader=None):
    # Prepare
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.il_batch_size)
    nb_frames = 0
    train_stats = DictList()
    curr_best = 100000

    while True:
        if train_steps > FLAGS.il_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.il_save_freq == 0:
            with open(os.path.join(training_folder, 'bot{}.pkl'.format(train_steps)), 'wb') as f:
                torch.save(bot, f)

        if train_steps % FLAGS.il_eval_freq == 0:
            # testing on valid
            val_metrics = evaluate_on_envs(bot, dataloader)
            logging_metrics(nb_frames, train_steps, val_metrics, writer, prefix='val')

            # testing on test env
            if test_dataloader is not None:
                test_metrics = evaluate_on_envs(bot, test_dataloader)
                logging_metrics(nb_frames, train_steps, test_metrics, writer, prefix='test')

            avg_loss = [val_metrics[env_name]['loss'] for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
                    torch.save(bot, f)

        # Forward/Backward
        bot.train()
        train_batch, train_lengths, train_sketch_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lengths = train_sketch_lengths.cuda()

        start = time.time()
        #train_batch_res = run_batch(train_batch, train_lengths, train_sketch_lengths, bot)
        train_batch_res, _ = bot.teacherforcing_batch(
            train_batch, train_lengths, train_sketch_lengths,
            recurrence=FLAGS.il_recurrence)
        train_batch_res.apply(lambda _t: _t.sum() / train_lengths.sum())
        batch_time = time.time() - start

        loss = train_batch_res.loss
        opt.zero_grad()
        loss.backward()
        params = [p for p in bot.parameters() if p.requires_grad]
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params,
                                                   max_norm=FLAGS.il_clip)
        opt.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()
        fps = train_lengths.sum().item() / batch_time

        stats = DictList()
        stats.grad_norm = grad_norm
        stats.loss = train_batch_res.loss.detach()
        stats.fps = torch.tensor(fps)
        train_stats.append(stats)

        if train_steps % FLAGS.il_eval_freq == 0:
            train_stats.apply(lambda _tensors: torch.stack(_tensors).mean().item())
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
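# A hedged sketch of how main_loop might be driven. The real entry point
# builds the bot according to FLAGS.arch and its own learning-rate flag,
# neither of which appears in this excerpt, so the names below (build_bot,
# the literal learning rate) are placeholders rather than the repo's code.
def run_il_sketch(training_folder):
    dataloader = Dataloader(FLAGS.sketch_lengths, FLAGS.il_val_ratio)
    bot = build_bot(FLAGS.arch)  # hypothetical factory for the chosen policy
    if FLAGS.cuda:
        bot = bot.cuda()
    opt = torch.optim.Adam(bot.parameters(), lr=1e-3)  # lr flag not shown here
    main_loop(bot, dataloader, opt, training_folder)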
def main(training_folder):
    logging.info('Start compile...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = compile.CompILE(vec_size=39,
                            hidden_size=FLAGS.hidden_size,
                            action_size=9,
                            env_arch=FLAGS.env_arch,
                            max_num_segments=FLAGS.compile_max_segs,
                            latent_dist=FLAGS.compile_latent,
                            beta_b=FLAGS.compile_beta_b,
                            beta_z=FLAGS.compile_beta_z,
                            prior_rate=FLAGS.compile_prior_rate,
                            dataloader=dataloader)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.compile_lr)
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.compile_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()

    while True:
        if train_steps > FLAGS.compile_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.compile_eval_freq == 0:
            # Testing
            val_metrics = {}
            model.eval()
            for env_name in FLAGS.sketch_lengths:
                val_metrics[env_name] = DictList()
                val_iter = dataloader.val_iter(batch_size=FLAGS.compile_batch_size,
                                               env_names=[env_name])
                for val_batch, val_lengths, val_sketch_lens in val_iter:
                    if FLAGS.cuda:
                        val_batch.apply(lambda _t: _t.cuda())
                        val_lengths = val_lengths.cuda()
                        val_sketch_lens = val_sketch_lens.cuda()
                    with torch.no_grad():
                        val_outputs, extra_info = model.forward(val_batch, val_lengths, val_sketch_lens)
                    val_metrics[env_name].append(val_outputs)

                # Parsing
                total_lengths = 0
                total_task_corrects = 0
                val_iter = dataloader.val_iter(batch_size=FLAGS.eval_episodes,
                                               env_names=[env_name],
                                               shuffle=True)
                val_batch, val_lengths, val_sketch_lens = next(val_iter)
                if FLAGS.cuda:
                    val_batch.apply(lambda _t: _t.cuda())
                    val_lengths = val_lengths.cuda()
                    val_sketch_lens = val_sketch_lens.cuda()
                with torch.no_grad():
                    val_outputs, extra_info = model.forward(val_batch, val_lengths, val_sketch_lens)
                seg = torch.stack(extra_info['segment'], dim=1).argmax(-1)

                for batch_id, (length, sketch_length, _seg) in enumerate(
                        zip(val_lengths, val_sketch_lens, seg)):
                    traj = val_batch[batch_id]
                    traj = traj[:length]
                    _gt_subtask = traj.gt_onsets
                    target = point_of_change(_gt_subtask)
                    _seg = _seg[_seg.sort()[1]].cpu().tolist()

                    # Remove the last one because too trivial
                    val_metrics[env_name].append({'f1_tol0': f1(target, _seg, 0),
                                                  'f1_tol1': f1(target, _seg, 1),
                                                  'f1_tol2': f1(target, _seg, 2)})

                    # subtask
                    total_lengths += length.item()
                    _decoded_subtask = get_subtask_seq(length.item(),
                                                       subtask=traj.tasks.tolist(),
                                                       use_ids=np.array(_seg))
                    total_task_corrects += (_gt_subtask.cpu() == _decoded_subtask.cpu()).float().sum()

                # record task acc
                val_metrics[env_name].task_acc = total_task_corrects / total_lengths

                # Print parsing result
                lines = []
                lines.append('tru_ids: {}'.format(target))
                lines.append('dec_ids: {}'.format(_seg))
                logging.info('\n'.join(lines))

                val_metrics[env_name].apply(lambda _t: torch.tensor(_t).float().mean().item())

            # Logger
            for env_name, metric in val_metrics.items():
                line = ['[VALID][{}] steps={}'.format(env_name, train_steps)]
                for k, v in metric.items():
                    line.append('{}: {:.4f}'.format(k, v))
                logging.info('\t'.join(line))

            mean_val_metric = DictList()
            for metric in val_metrics.values():
                mean_val_metric.append(metric)
            mean_val_metric.apply(lambda t: torch.mean(torch.tensor(t)))
            for k, v in mean_val_metric.items():
                writer.add_scalar('val/' + k, v.item(), nb_frames)
            writer.flush()

            avg_loss = [val_metrics[env_name].loss for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_sketch_lens = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lens = train_sketch_lens.cuda()
        train_outputs, _ = model.forward(train_batch, train_lengths, train_sketch_lens)

        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()

        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.compile_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
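# The parsing metrics in the CompILE evaluation call the project helpers
# `point_of_change` and `f1(target, prediction, tolerance)`, which are not
# included in this excerpt. The sketches below illustrate the boundary F1 with
# tolerance they are assumed to compute: a predicted segment boundary counts
# as a hit if it falls within `tol` steps of a still-unmatched ground-truth
# boundary. These are illustrative stand-ins, not the repo's actual code.
def point_of_change_sketch(labels):
    # Timesteps where the per-step subtask id changes.
    return [t for t in range(1, len(labels)) if labels[t] != labels[t - 1]]

def f1_with_tolerance_sketch(target, prediction, tol):
    if not target or not prediction:
        return 0.0
    matched = set()
    hits = 0
    for p in prediction:
        for i, t in enumerate(target):
            if i not in matched and abs(p - t) <= tol:
                matched.add(i)
                hits += 1
                break
    precision = hits / len(prediction)
    recall = hits / len(target)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)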