def teacherforcing_batch(self, batch: DictList, batch_lengths, sketch_lengths,
                         recurrence) -> (DictList, DictList):
    """
    :param batch: DictList object [bsz, seqlen]
    :param batch_lengths: [bsz]
    :param sketch_lengths: [bsz]
    :param recurrence: an int, detach the memory every `recurrence` steps
    :return: final_outputs: A DictList of [bsz, seqlen] per-step outputs
                            (logprobs, loss, optionally log_end)
             extra_info: A DictList of extra info
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    extra_info = DictList({})
    mems = None
    if self.is_recurrent:
        mems = self.init_memory(sketchs, sketch_lengths)

    for t in range(seqlen):
        final_output = DictList({})
        model_output = self.forward(batch.states[:, t], sketchs, sketch_lengths, mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # Mixture: p_end + (1 - p_end) * action_prob, computed in log space
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1), dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        if 'p' in model_output:
            extra_info.append({'p': model_output.p})
        final_outputs.append(final_output)

        # Update memory, truncating backprop every `recurrence` steps
        next_mems = None
        if self.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % recurrence == 0:
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack on time dim
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
    extra_info.apply(lambda _tensors: torch.stack(_tensors, dim=1))

    # Mask out padded steps beyond each trajectory's length
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        # At the last valid step the target is termination, so the loss there
        # is the negative log-probability of ending
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = -final_outputs.log_end[
            batch_ids, batch_lengths - 1]
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs, extra_info
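
# A minimal, self-contained sketch of the log-space mixture used above: the
# per-step action log-probability is combined with the termination probability
# as log(p_end + (1 - p_end) * p_action) via torch.logsumexp. The tensors
# below are illustrative stand-ins, not the model's real outputs.
def _example_log_end_mixture():
    import torch
    log_end = torch.log(torch.tensor([0.1, 0.8]))          # log p_end, [bsz]
    log_no_end = torch.log(torch.tensor([0.9, 0.2]))        # log (1 - p_end), [bsz]
    action_logprob = torch.log(torch.tensor([0.5, 0.5]))    # log p(a_t), [bsz]
    mixed = torch.logsumexp(
        torch.stack([log_end, log_no_end + action_logprob], dim=-1), dim=-1)
    # Equivalent direct computation, up to floating point error
    expected = torch.log(torch.tensor([0.1 + 0.9 * 0.5, 0.8 + 0.2 * 0.5]))
    assert torch.allclose(mixed, expected)
    return mixed
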
def run_batch(batch: DictList, batch_lengths, sketch_lengths, bot: ModelBot,
              mode='train') -> DictList:
    """
    :param batch: DictList object [bsz, seqlen]
    :param batch_lengths: [bsz]
    :param sketch_lengths: [bsz]
    :param bot: A model Bot
    :param mode: 'train' or 'eval'
    :return: final_outputs: A DictList of [bsz, seqlen] per-step outputs
                            (logprobs, loss, optionally log_end)
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    mems = None
    if bot.is_recurrent:
        mems = bot.init_memory(sketchs, sketch_lengths)

    for t in range(seqlen):
        final_output = DictList({})
        model_output = bot.forward(batch.states[:, t], sketchs, sketch_lengths, mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # Mixture: p_end + (1 - p_end) * action_prob, computed in log space
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1), dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        final_outputs.append(final_output)

        # Update memory, truncating backprop every FLAGS.il_recurrence steps
        # during training
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % FLAGS.il_recurrence == 0 and mode == 'train':
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack on time dim
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))

    # Mask out padded steps beyond each trajectory's length
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        # At the last valid step the target is termination, so the loss there
        # is the negative log-probability of ending
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = -final_outputs.log_end[
            batch_ids, batch_lengths - 1]
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs
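
# A minimal sketch of the sequence masking used in both batch functions: a
# [bsz, max_len] boolean mask is built by broadcasting a step index against
# per-trajectory lengths, and padded positions are zeroed with masked_fill.
# The lengths below are illustrative, not taken from a real batch.
def _example_sequence_mask():
    import torch
    batch_lengths = torch.tensor([3, 1, 2])
    loss = torch.ones(3, batch_lengths.max().item())
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    # Positions at or beyond each trajectory's length contribute zero loss
    masked = loss.masked_fill(~sequence_mask, 0.)
    assert masked.sum().item() == batch_lengths.sum().item()
    return masked
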
def main_loop(bot, dataloader, opt, training_folder, test_dataloader=None):
    # Prepare
    train_steps = 0
    writer = SummaryWriter(training_folder)
    train_iter = dataloader.train_iter(batch_size=FLAGS.il_batch_size)
    nb_frames = 0
    train_stats = DictList()
    curr_best = 100000
    while True:
        if train_steps > FLAGS.il_train_steps:
            logging.info('Reaching maximum steps')
            break

        if train_steps % FLAGS.il_save_freq == 0:
            with open(os.path.join(training_folder,
                                   'bot{}.pkl'.format(train_steps)), 'wb') as f:
                torch.save(bot, f)

        if train_steps % FLAGS.il_eval_freq == 0:
            # Testing on valid
            val_metrics = evaluate_on_envs(bot, dataloader)
            logging_metrics(nb_frames, train_steps, val_metrics, writer, prefix='val')

            # Testing on test env
            if test_dataloader is not None:
                test_metrics = evaluate_on_envs(bot, test_dataloader)
                logging_metrics(nb_frames, train_steps, test_metrics, writer,
                                prefix='test')

            avg_loss = [val_metrics[env_name]['loss'] for env_name in val_metrics]
            avg_loss = np.mean(avg_loss)
            if avg_loss < curr_best:
                curr_best = avg_loss
                logging.info('Save Best with loss: {}'.format(avg_loss))

                # Save the checkpoint
                with open(os.path.join(training_folder, 'bot_best.pkl'), 'wb') as f:
                    torch.save(bot, f)

        # Forward/Backward
        bot.train()
        train_batch, train_lengths, train_sketch_lengths = next(train_iter)
        if FLAGS.cuda:
            train_batch.apply(lambda _t: _t.cuda())
            train_lengths = train_lengths.cuda()
            train_sketch_lengths = train_sketch_lengths.cuda()

        start = time.time()
        #train_batch_res = run_batch(train_batch, train_lengths, train_sketch_lengths, bot)
        train_batch_res, _ = bot.teacherforcing_batch(train_batch, train_lengths,
                                                      train_sketch_lengths,
                                                      recurrence=FLAGS.il_recurrence)
        # Average each per-step quantity over the total number of valid steps
        train_batch_res.apply(lambda _t: _t.sum() / train_lengths.sum())
        batch_time = time.time() - start

        loss = train_batch_res.loss
        opt.zero_grad()
        loss.backward()
        params = [p for p in bot.parameters() if p.requires_grad]
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params,
                                                   max_norm=FLAGS.il_clip)
        opt.step()
        train_steps += 1
        nb_frames += train_lengths.sum().item()
        fps = train_lengths.sum().item() / batch_time

        stats = DictList()
        stats.grad_norm = grad_norm
        stats.loss = train_batch_res.loss.detach()
        stats.fps = torch.tensor(fps)
        train_stats.append(stats)

        if train_steps % FLAGS.il_eval_freq == 0:
            train_stats.apply(lambda _tensors: torch.stack(_tensors).mean().item())
            logger_str = ['[TRAIN] steps={}'.format(train_steps)]
            for k, v in train_stats.items():
                logger_str.append("{}: {:.4f}".format(k, v))
                writer.add_scalar('train/' + k, v, global_step=nb_frames)
            logging.info('\t'.join(logger_str))
            train_stats = DictList()
            writer.flush()
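
# A minimal, self-contained sketch of the optimization step pattern used in
# main_loop: zero the gradients, backpropagate a scalar loss, clip the global
# gradient norm with torch.nn.utils.clip_grad_norm_, then step the optimizer.
# The toy model, optimizer, and max_norm value are illustrative assumptions,
# not the repository's actual configuration.
def _example_clipped_update():
    import torch
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    params = [p for p in model.parameters() if p.requires_grad]
    grad_norm = torch.nn.utils.clip_grad_norm_(parameters=params, max_norm=1.0)
    opt.step()
    return grad_norm
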