def batch_iter(self, trajs, batch_size, shuffle=True, epochs=-1):
    """
    Yield padded batches of trajectories.

    :param trajs: A list of DictList
    :param batch_size: int
    :param shuffle: bool. Reshuffle trajs at the start of each epoch
    :param epochs: int. If -1, iterate forever
    :return: yields (batch, lengths, task_lengths), where batch is a
        DictList of shape [bsz, seq_len] and lengths / task_lengths are
        1-D tensors holding the unpadded lengths
    """
    epoch_iter = range(1, epochs + 1) if epochs > 0 else _forever()
    for _ in epoch_iter:
        if shuffle:
            random.shuffle(trajs)
        start_idx = 0
        while start_idx < len(trajs):
            batch = DictList()
            lengths = []
            task_lengths = []
            for _traj in trajs[start_idx:start_idx + batch_size]:
                lengths.append(len(_traj.actions))
                task_lengths.append(len(_traj.tasks))
                _traj.apply(lambda _t: torch.tensor(_t))
                batch.append(_traj)
            # Right-pad every field to the longest trajectory in the batch
            batch.apply(lambda _t: pad_sequence(_t, batch_first=True))
            yield batch, torch.tensor(lengths), torch.tensor(task_lengths)
            start_idx += batch_size
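
# Illustrative sketch (added, not part of the original API): `batch_iter`
# relies on `pad_sequence` right-padding variable-length fields to a common
# length, which is why the true lengths are yielded alongside the batch.
def _demo_pad_sequence():
    import torch
    from torch.nn.utils.rnn import pad_sequence

    seqs = [torch.ones(3), torch.ones(5)]           # trajectories of length 3 and 5
    padded = pad_sequence(seqs, batch_first=True)   # shape [2, 5]; zeros fill the tail
    lengths = torch.tensor([len(s) for s in seqs])  # kept so masks can be rebuilt later
    return padded, lengths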
def batch_evaluate(envs: List, bot: ModelBot, cuda, verbose=False) -> DictList:
    """ Roll the bot out in a batch of envs and return averaged metrics. """
    obs = DictList()
    for env in envs:
        obs.append(DictList(env.reset()))
    obs.apply(lambda _t: torch.tensor(_t).float())
    actives = torch.tensor([i for i in range(len(envs))])
    if cuda:
        obs.apply(lambda _t: _t.cuda())
        actives = actives.cuda()
    trajs = [DictList() for _ in range(len(envs))]
    sketchs = obs.sketch.long()[0]
    sketch_lengths = torch.tensor(sketchs.shape, device=sketchs.device)
    mems = bot.init_memory(
        sketchs.unsqueeze(0).repeat(len(actives), 1),
        sketch_lengths.repeat(len(actives))) if bot.is_recurrent else None

    # Continue the rollout while at least one env is active
    steps = 0
    while len(actives) > 0:
        if verbose:
            print('active env:', len(actives))
        active_trajs = [trajs[i] for i in actives]
        with torch.no_grad():
            model_outputs = bot.get_action(
                obs.state,
                sketchs.unsqueeze(0).repeat(len(actives), 1),
                sketch_lengths.repeat(len(actives)), mems)
        actions = model_outputs.actions
        next_obs, rewards, dones = step_batch_envs(envs, actions, actives, cuda)
        transition = DictList({'rewards': rewards})
        transition.update(obs)
        for idx, active_traj in enumerate(active_trajs):
            active_traj.append(transition[idx])
        steps += 1

        # Memory
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_outputs.mems

        # Keep only the envs that are not done for the next step
        un_done_ids = (~dones).nonzero().squeeze(-1)
        obs = next_obs[un_done_ids]
        actives = actives[un_done_ids]
        mems = next_mems[un_done_ids] if next_mems is not None else None

    metric = DictList()
    for traj, env in zip(trajs, envs):
        traj.apply(lambda _tensors: torch.stack(_tensors))
        metric.append({
            'ret': sum(env.local_score),
            'succs': traj.rewards.sum().item(),
            'length': len(traj.rewards)
        })
    metric.apply(lambda _t: np.mean(_t))
    return metric
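
# Illustrative sketch (added): the active-environment bookkeeping above relies
# on index filtering, so after each step only the rows belonging to envs that
# are not done survive into the next iteration.
def _demo_active_filtering():
    import torch

    dones = torch.tensor([False, True, False])
    actives = torch.tensor([0, 1, 2])
    un_done_ids = (~dones).nonzero().squeeze(-1)  # tensor([0, 2])
    return actives[un_done_ids]                   # env 1 is dropped from the rollout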
def step_batch_envs(envs, actions, actives, cuda):
    """
    Step the currently active envs in a batch.
    Return obss, rewards, dones of the active envs only.
    """
    assert actions.shape[0] == len(actives)
    active_envs = [envs[i] for i in actives]
    obss = DictList()
    rewards = []
    dones = []
    for action, env in zip(actions, active_envs):
        obs, reward, done, _ = env.step(action.cpu().numpy())
        obss.append(obs)
        rewards.append(reward)
        dones.append(done)
    obss.apply(lambda _t: torch.tensor(_t).float())
    rewards = torch.tensor(rewards).float()
    dones = torch.tensor(dones)
    if cuda:
        obss.apply(lambda _t: _t.cuda())
        rewards = rewards.cuda()
        dones = dones.cuda()
    return obss, rewards, dones
def evaluate_on_envs(bot, dataloader):
    val_metrics = {}
    bot.eval()
    envs = dataloader.env_names
    for env_name in envs:
        val_iter = dataloader.val_iter(batch_size=FLAGS.il_batch_size,
                                       env_names=[env_name],
                                       shuffle=True)
        output = DictList({})
        total_lengths = 0
        for batch, batch_lens, batch_sketch_lens in val_iter:
            if FLAGS.cuda:
                batch.apply(lambda _t: _t.cuda())
                batch_lens = batch_lens.cuda()
                batch_sketch_lens = batch_sketch_lens.cuda()
            with torch.no_grad():
                start = time.time()
                batch_results, _ = bot.teacherforcing_batch(
                    batch, batch_lens, batch_sketch_lens,
                    recurrence=FLAGS.il_recurrence)
                end = time.time()
                print('batch time', end - start)
            batch_results.apply(lambda _t: _t.sum().item())
            output.append(batch_results)
            total_lengths += batch_lens.sum().item()
            if FLAGS.debug:
                break
        # Average the summed statistics over the total number of frames
        output.apply(lambda _t: torch.tensor(_t).sum().item() / total_lengths)
        val_metrics[env_name] = {k: v for k, v in output.items()}

    # Parsing
    if 'om' in FLAGS.arch:
        with torch.no_grad():
            parsing_stats, parsing_lines = parsing_loop(
                bot,
                dataloader=dataloader,
                batch_size=FLAGS.il_batch_size,
                cuda=FLAGS.cuda)
        for env_name in parsing_stats:
            parsing_stats[env_name].apply(lambda _t: np.mean(_t))
            val_metrics[env_name].update(parsing_stats[env_name])
        logging.info('Get parsing result')
        logging.info('\n' + '\n'.join(parsing_lines))

    # Evaluation on free-run envs (disabled):
    # if not FLAGS.debug:
    #     for sketch_length in val_metrics:
    #         envs = [gym.make('jacopinpad-v0', sketch_length=sketch_length,
    #                          max_steps_per_sketch=FLAGS.max_steps_per_sketch)
    #                 for _ in range(FLAGS.eval_episodes)]
    #         with torch.no_grad():
    #             free_run_metric = batch_evaluate(envs=envs, bot=bot, cuda=FLAGS.cuda)
    #         val_metrics[sketch_length].update(free_run_metric)
    return val_metrics
def teacherforcing_batch(self, batch: DictList, batch_lengths, sketch_lengths,
                         recurrence) -> (DictList, DictList):
    """
    :param batch: DictList object [bsz, seqlen]
    :param batch_lengths: [bsz]
    :param sketch_lengths: [bsz]
    :param recurrence: int. Detach the memory every `recurrence` steps
    :return:
        final_outputs: A DictList of [bsz, seqlen] tensors (logprobs, loss, ...)
        extra_info: A DictList of extra per-step info
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    extra_info = DictList({})
    mems = None
    if self.is_recurrent:
        mems = self.init_memory(sketchs, sketch_lengths)
    for t in range(seqlen):
        final_output = DictList({})
        model_output = self.forward(batch.states[:, t], sketchs,
                                    sketch_lengths, mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # log(p_end + (1 - p_end) * p_action), computed stably in log space
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1), dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        if 'p' in model_output:
            extra_info.append({'p': model_output.p})
        final_outputs.append(final_output)

        # Update memory, detaching every `recurrence` steps (truncated BPTT)
        next_mems = None
        if self.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % recurrence == 0:
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack on the time dim
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
    extra_info.apply(lambda _tensors: torch.stack(_tensors, dim=1))
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        # At each sequence's final valid step, the loss is replaced by the
        # termination term
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = \
            final_outputs.log_end[batch_ids, batch_lengths - 1]
    # Zero out everything beyond each sequence's true length
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs, extra_info
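
# Illustrative sketch (added): two tricks used in `teacherforcing_batch`,
# shown in isolation on toy values.
def _demo_mask_and_logsumexp():
    import torch

    # (1) Length mask: position t is valid iff t < length of that sequence.
    batch_lengths = torch.tensor([2, 4])
    mask = torch.arange(batch_lengths.max().item())[None, :] < batch_lengths[:, None]
    # mask -> [[True, True, False, False],
    #          [True, True, True,  True ]]

    # (2) Mixing termination into the action likelihood in log space:
    # log(p_end + (1 - p_end) * p_action) via logsumexp, where
    # log_no_end = log(1 - p_end).
    log_end, log_no_end, logprob = torch.log(torch.tensor([0.1, 0.9, 0.5]))
    mixed = torch.logsumexp(torch.stack([log_end, log_no_end + logprob]), dim=0)
    return mask, mixed  # mixed == log(0.1 + 0.9 * 0.5)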
def run_batch(batch: DictList, batch_lengths, sketch_lengths, bot: ModelBot,
              mode='train') -> DictList:
    """
    Teacher-force one batch through the bot (standalone variant of
    `ModelBot.teacherforcing_batch`).

    :param batch: DictList object [bsz, seqlen]
    :param bot: A ModelBot
    :param mode: 'train' or 'val'. The memory is only detached in 'train' mode
    :return: final_outputs: A DictList of [bsz, seqlen] tensors
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    mems = None
    if bot.is_recurrent:
        mems = bot.init_memory(sketchs, sketch_lengths)
    for t in range(seqlen):
        final_output = DictList({})
        model_output = bot.forward(batch.states[:, t], sketchs,
                                   sketch_lengths, mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # log(p_end + (1 - p_end) * p_action), computed stably in log space
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1), dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        final_outputs.append(final_output)

        # Update memory, detaching every FLAGS.il_recurrence steps when training
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % FLAGS.il_recurrence == 0 and mode == 'train':
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack on the time dim
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        # At each sequence's final valid step, the loss is replaced by the
        # termination term
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = \
            final_outputs.log_end[batch_ids, batch_lengths - 1]
    # Zero out everything beyond each sequence's true length
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs
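
# Illustrative sketch (added): the periodic `detach()` above is truncated
# backpropagation through time; gradients flow through at most `recurrence`
# steps of recurrent memory, bounding the depth of the backward graph.
def _demo_truncated_bptt(recurrence=4, steps=12):
    import torch

    h = torch.zeros(1, 8, requires_grad=True)
    for t in range(steps):
        h = h * 0.9 + 1.0                # stand-in for one recurrent update
        if (t + 1) % recurrence == 0:
            h = h.detach()               # cut the gradient path here
    return h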
def evaluate_on_env(modular_p: ModularPolicy, sketch_length,
                    max_steps_per_sketch, use_sketch_id=False):
    start = time.time()
    env = gym.make('jacopinpad-v0', sketch_length=sketch_length,
                   max_steps_per_sketch=max_steps_per_sketch)
    device = next(modular_p.parameters()).device
    modular_p.eval()
    obs = DictList(env.reset())
    modular_p.reset(subtasks=obs.sketch)
    obs.apply(lambda _t: torch.tensor(_t, device=device).float())
    done = False
    traj = DictList()
    try:
        while not done:
            if not use_sketch_id:
                action = modular_p.get_action(obs.state.unsqueeze(0))
            else:
                action = modular_p.get_action(
                    obs.state.unsqueeze(0),
                    sketch_idx=int(obs.sketch_idx.item()))
            if action is not None:
                next_obs, reward, done, _ = env.step(action.cpu().numpy()[0])
                transition = {'reward': reward,
                              'action': action,
                              'features': obs.state}
                traj.append(transition)
                obs = DictList(next_obs)
                obs.apply(lambda _t: torch.tensor(_t, device=device).float())
            else:
                # The policy signals termination by returning no action
                done = True
    except MujocoException:
        pass
    end = time.time()
    if 'reward' in traj:
        return {'succs': np.sum(traj.reward),
                'episode_length': len(traj.reward),
                'ret': sum(env.local_score),
                'runtime': end - start}
    else:
        return {'succs': 0, 'episode_length': 0, 'ret': 0,
                'runtime': end - start}
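
# Illustrative sketch (added): the device-discovery idiom used above. A
# module's device can be read off any of its parameters, which avoids
# threading a separate `device` argument through the call.
def _demo_module_device():
    import torch

    m = torch.nn.Linear(3, 3)
    device = next(m.parameters()).device  # cpu unless the module was moved
    x = torch.randn(1, 3, device=device)  # allocate inputs on the same device
    return m(x)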
def teacherforce_batch(modular_p: ModularPolicy, trajs: DictList, lengths,
                       subtask_lengths, dropout_p, decode=False):
    """
    Teacher-force a batch of trajectories through the modular policy.
    In training mode return the negative average sequence log prob;
    in decode mode return the decoded subtask accuracy plus one example.
    """
    dropout_p = 0. if decode else dropout_p
    unique_tasks = set()
    for subtask in trajs.tasks:
        for task_id in subtask:
            unique_tasks.add(task_id.item())
    unique_tasks = list(unique_tasks)

    # Forward for all unique tasks
    # all_task_results: [bsz, length, nb_unique_tasks]
    states = trajs.states.float()
    targets = trajs.actions.float()
    all_task_results = DictList()
    for task in unique_tasks:
        all_task_results.append(
            modular_p.forward(task, states, targets, dropout_p=dropout_p))
    all_task_results.apply(lambda _t: torch.stack(_t, dim=2))

    # Gather each trajectory's own subtask columns
    # results: [bsz, len, nb_tasks]
    subtasks = trajs.tasks
    results = DictList()
    for batch_id, subtask in enumerate(subtasks):
        curr_result = DictList()
        for task in subtask:
            task_id = unique_tasks.index(task.item())
            curr_result.append(all_task_results[batch_id, :, task_id])
        # [len, tasks]
        curr_result.apply(lambda _t: torch.stack(_t, dim=1))
        results.append(curr_result)
    results.apply(lambda _t: torch.stack(_t, dim=0))

    # Training
    if not decode:
        log_alphas = tac_forward_log(action_logprobs=results.action_logprobs,
                                     stop_logprobs=results.stop_logprobs,
                                     lengths=lengths,
                                     subtask_lengths=subtask_lengths)
        seq_logprobs = log_alphas[
            torch.arange(log_alphas.shape[0], device=log_alphas.device),
            lengths - 1, subtask_lengths - 1]
        avg_logprobs = seq_logprobs.sum() / lengths.sum()
        return {'loss': -avg_logprobs}

    # Decode
    else:
        alphas, _ = tac_forward(action_logprobs=results.action_logprobs,
                                stop_logprobs=results.stop_logprobs,
                                lengths=lengths,
                                subtask_lengths=subtask_lengths)
        decoded = alphas.argmax(-1)
        batch_ids = torch.arange(
            decoded.shape[0],
            device=decoded.device).unsqueeze(-1).repeat(1, decoded.shape[1])
        decoded_subtasks = subtasks[batch_ids, decoded]
        total_task_corrects = 0
        for idx, (subtask, decoded_subtask, action, length, gt) in enumerate(
                zip(subtasks, decoded_subtasks, trajs.actions, lengths,
                    trajs.gt_onsets)):
            _decoded_subtask = decoded_subtask[:length]
            _action = action[:length]
            gt = gt[:length]
            total_task_corrects += (gt == _decoded_subtask).float().sum()
        # The returned example is the last trajectory in the batch
        return {'task_acc': total_task_corrects / lengths.sum()}, \
               {'tru': gt, 'act': _action, 'dec': _decoded_subtask}
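
# Illustrative sketch (added): the advanced-indexing readout used above.
# `log_alphas` has shape [bsz, T, K]; the sequence log prob of trajectory i
# sits at position (i, lengths[i] - 1, subtask_lengths[i] - 1).
def _demo_alpha_readout():
    import torch

    log_alphas = torch.randn(2, 5, 3)
    lengths = torch.tensor([4, 5])
    subtask_lengths = torch.tensor([2, 3])
    final = log_alphas[torch.arange(2), lengths - 1, subtask_lengths - 1]
    return final  # shape [2]: one log prob per trajectory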
def main(training_folder):
    logging.info('start taco...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = ModularPolicy(nb_subtasks=10, input_dim=39, n_actions=9,
                          a_mu=dataloader.a_mu, a_std=dataloader.a_std,
                          s_mu=dataloader.s_mu, s_std=dataloader.s_std)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.taco_lr)
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.taco_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()

    # Test dataloader for sketch lengths never seen during training
    test_sketch_lengths = set(FLAGS.test_sketch_lengths) - set(FLAGS.sketch_lengths)
    test_dataloader = None if len(test_sketch_lengths) == 0 else \
        Dataloader(test_sketch_lengths, FLAGS.il_val_ratio)
    scheduler = DropoutScheduler()
    while True:
        if train_steps > FLAGS.taco_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.taco_eval_freq == 0:
            val_metrics = evaluate_loop(dataloader, model,
                                        dropout_p=scheduler.dropout_p)
            logging_metrics(nb_frames, train_steps, val_metrics, writer, 'val')
            if test_dataloader is not None:
                test_metrics = evaluate_loop(test_dataloader, model,
                                             dropout_p=scheduler.dropout_p)
                logging_metrics(nb_frames, train_steps, test_metrics, writer,
                                'test')
            avg_loss = np.mean([val_metrics[env_name].loss
                                for env_name in val_metrics])
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_subtask_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_subtask_lengths = train_subtask_lengths.cuda()
        start = time.time()
        train_outputs = teacherforce_batch(modular_p=model,
                                           trajs=train_batch,
                                           lengths=train_lengths,
                                           subtask_lengths=train_subtask_lengths,
                                           decode=False,
                                           dropout_p=scheduler.dropout_p)
        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        scheduler.step()
        nb_frames += train_lengths.sum().item()
        end = time.time()
        fps = train_lengths.sum().item() / (end - start)
        train_outputs['fps'] = torch.tensor(fps)
        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.taco_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
def main_loop(bot, dataloader, opt, training_folder, test_dataloader=None):
    # Prepare
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.il_batch_size)
    nb_frames = 0
    train_stats = DictList()
    curr_best = np.inf
    while True:
        if train_steps > FLAGS.il_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.il_save_freq == 0:
            with open(os.path.join(training_folder,
                                   'bot{}.pkl'.format(train_steps)), 'wb') as f:
                torch.save(bot, f)

        if train_steps % FLAGS.il_eval_freq == 0:
            # Testing on valid
            val_metrics = evaluate_on_envs(bot, dataloader)
            logging_metrics(nb_frames, train_steps, val_metrics, writer,
                            prefix='val')

            # Testing on test env
            if test_dataloader is not None:
                test_metrics = evaluate_on_envs(bot, test_dataloader)
                logging_metrics(nb_frames, train_steps, test_metrics, writer,
                                prefix='test')

            avg_loss = np.mean([val_metrics[env_name]['loss']
                                for env_name in val_metrics])
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(bot, f)

        # Forward/Backward
        bot.train()
        train_batch, train_lengths, train_sketch_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lengths = train_sketch_lengths.cuda()
        start = time.time()
        train_batch_res, _ = bot.teacherforcing_batch(
            train_batch, train_lengths, train_sketch_lengths,
            recurrence=FLAGS.il_recurrence)
        # Per-frame average of the summed statistics
        train_batch_res.apply(lambda _t: _t.sum() / train_lengths.sum())
        batch_time = time.time() - start
        loss = train_batch_res.loss
        opt.zero_grad()
        loss.backward()
        params = [p for p in bot.parameters() if p.requires_grad]
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params,
                                                   max_norm=FLAGS.il_clip)
        opt.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()
        fps = train_lengths.sum().item() / batch_time

        stats = DictList()
        stats.grad_norm = grad_norm
        stats.loss = train_batch_res.loss.detach()
        stats.fps = torch.tensor(fps)
        train_stats.append(stats)

        if train_steps % FLAGS.il_eval_freq == 0:
            train_stats.apply(lambda _tensors: torch.stack(_tensors).mean().item())
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
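
# Illustrative sketch (added): `clip_grad_norm_` rescales gradients in place
# and returns the pre-clip total norm, which is why the single call above
# doubles as the logged `grad_norm` statistic.
def _demo_grad_clip():
    import torch

    m = torch.nn.Linear(4, 2)
    m(torch.randn(3, 4)).sum().backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1.0)
    return grad_norm  # total gradient norm before clipping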
def main(training_folder):
    logging.info('Start compile...')
    dataloader = Dataloader(FLAGS.sketch_lengths, 0.2)
    model = compile.CompILE(vec_size=39,
                            hidden_size=FLAGS.hidden_size,
                            action_size=9,
                            env_arch=FLAGS.env_arch,
                            max_num_segments=FLAGS.compile_max_segs,
                            latent_dist=FLAGS.compile_latent,
                            beta_b=FLAGS.compile_beta_b,
                            beta_z=FLAGS.compile_beta_z,
                            prior_rate=FLAGS.compile_prior_rate,
                            dataloader=dataloader)
    if FLAGS.cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.compile_lr)
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.compile_batch_size)
    nb_frames = 0
    curr_best = np.inf
    train_stats = DictList()
    while True:
        if train_steps > FLAGS.compile_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.compile_eval_freq == 0:
            # Testing
            val_metrics = {}
            model.eval()
            for env_name in FLAGS.sketch_lengths:
                val_metrics[env_name] = DictList()
                val_iter = dataloader.val_iter(
                    batch_size=FLAGS.compile_batch_size, env_names=[env_name])
                for val_batch, val_lengths, val_sketch_lens in val_iter:
                    if FLAGS.cuda:
                        val_batch.apply(lambda _t: _t.cuda())
                        val_lengths = val_lengths.cuda()
                        val_sketch_lens = val_sketch_lens.cuda()
                    with torch.no_grad():
                        val_outputs, extra_info = model.forward(
                            val_batch, val_lengths, val_sketch_lens)
                    val_metrics[env_name].append(val_outputs)

                # Parsing
                total_lengths = 0
                total_task_corrects = 0
                val_iter = dataloader.val_iter(batch_size=FLAGS.eval_episodes,
                                               env_names=[env_name],
                                               shuffle=True)
                val_batch, val_lengths, val_sketch_lens = next(val_iter)
                if FLAGS.cuda:
                    val_batch.apply(lambda _t: _t.cuda())
                    val_lengths = val_lengths.cuda()
                    val_sketch_lens = val_sketch_lens.cuda()
                with torch.no_grad():
                    val_outputs, extra_info = model.forward(
                        val_batch, val_lengths, val_sketch_lens)
                seg = torch.stack(extra_info['segment'], dim=1).argmax(-1)
                for batch_id, (length, sketch_length, _seg) in enumerate(
                        zip(val_lengths, val_sketch_lens, seg)):
                    traj = val_batch[batch_id]
                    traj = traj[:length]
                    _gt_subtask = traj.gt_onsets
                    target = point_of_change(_gt_subtask)
                    _seg = _seg[_seg.sort()[1]].cpu().tolist()
                    # Remove the last one because too trivial
                    val_metrics[env_name].append({
                        'f1_tol0': f1(target, _seg, 0),
                        'f1_tol1': f1(target, _seg, 1),
                        'f1_tol2': f1(target, _seg, 2)
                    })

                    # Subtask accuracy
                    total_lengths += length.item()
                    _decoded_subtask = get_subtask_seq(
                        length.item(),
                        subtask=traj.tasks.tolist(),
                        use_ids=np.array(_seg))
                    total_task_corrects += (
                        _gt_subtask.cpu() == _decoded_subtask.cpu()).float().sum()

                # Record task acc
                val_metrics[env_name].task_acc = total_task_corrects / total_lengths

                # Print the parsing result of the last example
                lines = []
                lines.append('tru_ids: {}'.format(target))
                lines.append('dec_ids: {}'.format(_seg))
                logging.info('\n'.join(lines))
                val_metrics[env_name].apply(
                    lambda _t: torch.tensor(_t).float().mean().item())

            # Logger
            for env_name, metric in val_metrics.items():
                line = ['[VALID][{}] steps={}'.format(env_name, train_steps)]
                for k, v in metric.items():
                    line.append('{}: {:.4f}'.format(k, v))
                logging.info('\t'.join(line))
            mean_val_metric = DictList()
            for metric in val_metrics.values():
                mean_val_metric.append(metric)
            mean_val_metric.apply(lambda t: torch.mean(torch.tensor(t)))
            for k, v in mean_val_metric.items():
                writer.add_scalar('val/' + k, v.item(), nb_frames)
            writer.flush()

            avg_loss = np.mean([val_metrics[env_name].loss
                                for env_name in val_metrics])
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))
                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'),
                          'wb') as f:
                    torch.save(model, f)

        model.train()
        train_batch, train_lengths, train_sketch_lens = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lens = train_sketch_lens.cuda()
        train_outputs, _ = model.forward(train_batch, train_lengths,
                                         train_sketch_lens)
        optimizer.zero_grad()
        train_outputs['loss'].backward()
        optimizer.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()
        train_outputs = DictList(train_outputs)
        train_outputs.apply(lambda _t: _t.item())
        train_stats.append(train_outputs)

        if train_steps % FLAGS.compile_eval_freq == 0:
            train_stats.apply(lambda _tensors: np.mean(_tensors))
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()