# Dependencies assumed to be imported at module level; `vtrace` and the
# compute_*_loss helpers are defined elsewhere in this project.
import threading

import torch
from torch import nn


def learn(
    flags,
    actor_model,
    model,
    batch,
    initial_agent_state,
    optimizer,
    scheduler,
    lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step."""
    with lock:
        """
        Put a lock on the central learner and send the trajectories to it.
        Update the parameters of the central learner, then copy the parameters
        of the central learner back to the actors.
        """
        # print('RUNNING MAIN MODEL')
        # print('MODEL OUTPUT: ', model(batch, initial_agent_state))
        # print('batch size: ', batch['frame'].size())

        # TODO Next Step: think about chopping the sequences up for attending to long sequences.
        # TODO Next Step: make the actors put only one sequence in one rollout, mask the remaining
        # rollout positions, and then change the triu in line 461 in TXL.
        # Here the entire 81-step sequence is used as the query and is attended autoregressively
        # to keys and values of length 81. Once this sequence length becomes very large, memory
        # will need to kick in.
        learner_outputs, unused_state, unused_mems, mem_padding, ind_first_done = model(
            batch, initial_agent_state, mems=None, mem_padding=None)
        # Here mem_padding is the same as the "batch" padding for this iteration,
        # so it can be used for masking the loss.

        # Take final value function slice for bootstrapping.
        # This is the final value from this trajectory.
        if ind_first_done is not None:
            # B-dimensional tensor
            bootstrap_value = learner_outputs["baseline"][
                ind_first_done, range(flags.batch_size)]
        else:
            bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }
        # Using learner_outputs to predict batch since batch is always one step
        # ahead of learner_outputs?

        rewards = batch["reward"]
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        discounts = (~batch["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],  # WHY IS THIS THE TARGET?
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        # TODO Next Step: the losses also have to be computed with the padding;
        # think about a mask structure to do this efficiently.
        # Advantages are [rollout_len, batch_size].
        # First mask out vtrace_returns.pg_advantages wherever there is padding,
        # which fixes pg_loss.
        pad_mask = (~(mem_padding.squeeze(0)[1:])).float() if mem_padding is not None else None

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"],
            batch["action"],
            vtrace_returns.pg_advantages,
            pad_mask)
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"], pad_mask)
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"], pad_mask)

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch["episode_return"][batch["done"]]
        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()),
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }

        optimizer.zero_grad()
        total_loss.backward()
        if flags.fp16:
            optimizer.clip_master_grads(flags.grad_norm_clipping)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        # The scheduler is stepped inside the lock of batch_and_learn itself.
        # scheduler.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
def learn(
    flags,
    actor_model,
    model,
    batch,
    initial_agent_state,
    optimizer,
    scheduler,
    lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step."""
    with lock:
        learner_outputs, unused_state = model(batch, initial_agent_state)

        # Take final value function slice for bootstrapping.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch["reward"]
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        discounts = (~batch["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"],
            batch["action"],
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"])
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"])

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch["episode_return"][batch["done"]]
        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()),
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }

        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
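# --- Sketch (illustrative only): how the [T+1, B] rollout above is shifted
# before the V-trace call. The toy sizes T, B and the 0.99 discount are made up;
# the point is that dropping the first batch step and the last learner step
# leaves T aligned transitions, with the final baseline kept as the bootstrap.
def _vtrace_shift_demo():
    import torch

    T, B = 4, 2
    batch = {
        "reward": torch.randn(T + 1, B),
        "done": torch.zeros(T + 1, B, dtype=torch.bool),
    }
    learner_outputs = {"baseline": torch.randn(T + 1, B)}

    bootstrap_value = learner_outputs["baseline"][-1]  # [B]
    batch = {key: tensor[1:] for key, tensor in batch.items()}
    learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}

    assert batch["reward"].shape == (T, B)
    assert learner_outputs["baseline"].shape == (T, B)
    assert bootstrap_value.shape == (B,)
    discounts = (~batch["done"]).float() * 0.99  # stand-in for flags.discounting
    return discounts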
def learn(
    flags,
    actor_model,
    model,
    batch,
    initial_agent_state,
    optimizer,
    scheduler,
    lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step."""
    with lock:
        """
        Put a lock on the central learner and send the trajectories to it.
        Update the parameters of the central learner, then copy the parameters
        of the central learner back to the actors.
        """
        # TODO: Chop up batch into smaller pieces to run through TXL one at a time (caching previous as memory)
        # TODO: Change batch function to look for trajectories of similar lengths
        # TODO: Add in adaptive attention (and think of how things change, e.g. no memory)
        # print({key: batch[key].shape for key in batch})
        mems, mem_padding = None, None
        for i in range(0, flags.unroll_length + 1, flags.chunk_size):
            mini_batch = {
                key: batch[key][i:i + flags.chunk_size]
                for key in batch if key != 'len_traj'
            }
            # Note that initial_agent_state isn't used by the transformer (I think this is the
            # hidden state); this will need to change to use this with an LSTM.
            # TODO: Need to change batch -> mini_batch (the batch name gets overwritten).
            tmp_mask = torch.zeros_like(mini_batch["done"]).bool()
            learner_outputs, unused_state, mems, mem_padding, ind_first_done = model(
                mini_batch, initial_agent_state, mems=mems, mem_padding=mem_padding)
            # Here mem_padding is the same as the "batch" padding for this iteration,
            # so it can be used for masking the loss.
            # if mini_batch["done"].any().item():
            #     print('Indfirstdone: ', ind_first_done)
            #     print('miniBATCH DONE: ', mini_batch["done"])
            #     print('Mem padding: ', mem_padding)

            # Take final value function slice for bootstrapping.
            # This is the final value from this trajectory.
            if ind_first_done is not None:
                # B-dimensional tensor
                bootstrap_value = learner_outputs["baseline"][
                    ind_first_done, range(flags.batch_size)]
            else:
                bootstrap_value = learner_outputs["baseline"][-1]

            # Move from obs[t] -> action[t] to action[t] -> obs[t].
            mini_batch = {
                key: tensor[1:] for key, tensor in mini_batch.items()
            }
            learner_outputs = {
                key: tensor[:-1] for key, tensor in learner_outputs.items()
            }
            # Using learner_outputs to predict mini_batch since mini_batch is always
            # one step ahead of learner_outputs?

            rewards = mini_batch["reward"]
            if flags.reward_clipping == "abs_one":
                clipped_rewards = torch.clamp(rewards, -1, 1)
            elif flags.reward_clipping == "none":
                clipped_rewards = rewards

            discounts = (~mini_batch["done"]).float() * flags.discounting

            vtrace_returns = vtrace.from_logits(
                behavior_policy_logits=mini_batch["policy_logits"],
                target_policy_logits=learner_outputs["policy_logits"],  # WHY IS THIS THE TARGET?
                actions=mini_batch["action"],
                discounts=discounts,
                rewards=clipped_rewards,
                values=learner_outputs["baseline"],
                bootstrap_value=bootstrap_value,
            )

            # TODO Next Step: the losses also have to be computed with the padding;
            # think about a mask structure to do this efficiently.
            # Advantages are [rollout_len, batch_size].
            # First mask out vtrace_returns.pg_advantages wherever there is padding,
            # which fixes pg_loss.
            pad_mask = (~(mem_padding.squeeze(0)[1:])).float() if mem_padding is not None else None

            pg_loss = compute_policy_gradient_loss(
                learner_outputs["policy_logits"],
                mini_batch["action"],
                vtrace_returns.pg_advantages,
                pad_mask)
            baseline_loss = flags.baseline_cost * compute_baseline_loss(
                vtrace_returns.vs - learner_outputs["baseline"], pad_mask)
            entropy_loss = flags.entropy_cost * compute_entropy_loss(
                learner_outputs["policy_logits"], pad_mask)

            total_loss = pg_loss + baseline_loss + entropy_loss

            # tmp_mask is defined above (before the time shift).
            if ind_first_done is not None:
                rows_to_use = []
                cols_to_use = []
                # `col` instead of `i` so the chunk index above is not shadowed.
                for col, val in enumerate(ind_first_done):
                    if val != -1:
                        rows_to_use.append(val)
                        cols_to_use.append(col)
                tmp_mask[rows_to_use, cols_to_use] = True  # NOT RIGHT FOR COLS THAT DIDN'T FINISH
            tmp_mask = tmp_mask[1:]  # This is how they initially had it, so keep it like this.
            # if mini_batch["done"].any().item():
            #     print('TMP MASK: ', tmp_mask)
            #     print('BATCH DONE: ', mini_batch["done"])
            #     print('shape1: {}, shape2: {}'.format(tmp_mask.shape, mini_batch['done'].shape))
            # episode_returns = mini_batch["episode_return"][mini_batch["done"]]
            episode_returns = mini_batch["episode_return"][tmp_mask]
            stats = {
                "episode_returns": tuple(episode_returns.cpu().numpy()),
                "mean_episode_return": torch.mean(episode_returns).item(),
                "total_loss": total_loss.item(),
                "pg_loss": pg_loss.item(),
                "baseline_loss": baseline_loss.item(),
                "entropy_loss": entropy_loss.item(),
            }

            optimizer.zero_grad()
            total_loss.backward()
            if flags.fp16:
                optimizer.clip_master_grads(flags.grad_norm_clipping)
            else:
                nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
            optimizer.step()
            # The scheduler is stepped inside the lock of batch_and_learn itself.
            # scheduler.step()

            actor_model.load_state_dict(model.state_dict())

        return stats