def forward(self, x, memory, xbatch, mbatch):
    x = self.input_proj(x)

    # Append the input vectors to the memory to perform self-attention over both;
    # also compute the batch indices and the edge indices for the
    # self-attention computation.
    mem_cat = torch.cat([memory, x], 0)
    batch_cat = torch.cat([mbatch, xbatch], 0)
    t_edge_compute, (ei_cat, log_time_get_ei_from) = timer(
        gnns.utils.get_ei_from)(mbatch, xbatch)

    # Transformer block
    t_mhsa, (mem_tmp, log_time_mhsa) = timer(self.mhsa)(mem_cat, batch_cat,
                                                        ei_cat)
    t_norm1, mem_tmp = timer(self.norm1)(memory + mem_tmp)
    t_mlp, mem_update = timer(self.mlp)(mem_tmp)
    t_norm2, mem_update = timer(self.norm2)(mem_tmp + mem_update)

    # Compute forget and input gates
    f, i = self.proj(torch.cat([memory, mem_update], -1)).chunk(2, -1)

    # Update memory; this gating mechanism may be refined.
    memory = memory * torch.sigmoid(f) + mem_update * torch.sigmoid(i)

    # For now the output is the memory.
    output = memory

    log_time = {
        't_edge_compute': t_edge_compute,
        'details_edge_compute': log_time_get_ei_from,
        't_mhsa': t_mhsa,
        'details_mhsa': log_time_mhsa,
        't_norm1': t_norm1,
        't_mlp': t_mlp,
        't_norm2': t_norm2
    }

    return output, memory, log_time
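# `timer` is assumed to be a small profiling wrapper: it wraps a callable and
# returns the elapsed wall-clock time together with the callable's output,
# matching the unpacking pattern `t, out = timer(fn)(*args)` used throughout
# this file. A minimal sketch of such a wrapper (an assumption, not the
# repository's actual implementation):

import time

def timer_sketch(fn):
    def timed(*args, **kwargs):
        t0 = time.time()
        out = fn(*args, **kwargs)
        return time.time() - t0, out
    return timed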
def forward(self, obs, memory, obs_batch, m_batch, instr_embedding=None):
    # if self.use_instr and instr_embedding is None:
    #     instr_embedding = self._get_instr_embedding(obs.instr)
    # if self.use_instr and self.lang_model == "attgru":
    #     # outputs: B x L x D
    #     # memory: B x M
    #     mask = (obs.instr != 0).float()
    #     # The mask tensor has the same length as obs.instr, and
    #     # thus can be both shorter and longer than instr_embedding.
    #     # It can be longer if instr_embedding is computed
    #     # for a subbatch of obs.instr.
    #     # It can be shorter if obs.instr is a subbatch of
    #     # the batch that instr_embeddings was computed for.
    #     # Here, we make sure that mask and instr_embeddings
    #     # have equal length along dimension 1.
    #     mask = mask[:, :instr_embedding.shape[1]]
    #     instr_embedding = instr_embedding[:, :mask.shape[1]]
    #
    #     keys = self.memory2key(memory)
    #     pre_softmax = (keys[:, None, :] * instr_embedding).sum(2) + 1000 * mask
    #     attention = F.softmax(pre_softmax, dim=1)
    #     instr_embedding = (instr_embedding * attention[:, :, None]).sum(1)

    t_slot_memory_model, (output, memory, log_time_slot_memory_model) = timer(
        self.slot_memory_model)(obs, memory, obs_batch, m_batch)

    # Pool the slot outputs into one embedding per environment.
    t_scatter_sum, embedding = timer(scatter_sum)(
        output, m_batch.type(torch.LongTensor))

    # if self.use_instr and not "filmcnn" in self.arch:
    #     embedding = torch.cat((embedding, instr_embedding), dim=1)

    if hasattr(self, 'aux_info') and self.aux_info:
        extra_predictions = {
            info: self.extra_heads[info](embedding)
            for info in self.extra_heads
        }
    else:
        extra_predictions = dict()

    t_actor, x = timer(self.actor)(embedding)
    dist = Categorical(logits=F.log_softmax(x, dim=1))

    t_critic, x = timer(self.critic)(embedding)
    value = x

    log_time = {
        't_slot_memory_model': t_slot_memory_model,
        'details_memory_model': log_time_slot_memory_model,
        't_scatter_sum': t_scatter_sum,
        't_actor': t_actor,
        't_critic': t_critic
    }

    return {
        'dist': dist,
        'value': value,
        'memory': memory,
        'extra_predictions': extra_predictions,
        'log_time': log_time
    }
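# The scatter_sum call above pools the per-slot outputs into one embedding per
# environment by summing the rows of `output` that share the same `m_batch`
# index (which is why the index is cast to a LongTensor). A minimal
# pure-PyTorch sketch of that dim-0 reduction, assuming `scatter_sum` behaves
# like torch_scatter's sum reduction along dim 0 (the name is illustrative,
# not the repository's helper):

import torch

def scatter_sum_sketch(src, index, dim_size=None):
    # src: [N, ...], index: [N] LongTensor of group ids -> out: [dim_size, ...]
    dim_size = int(index.max()) + 1 if dim_size is None else dim_size
    out = torch.zeros((dim_size,) + src.shape[1:], dtype=src.dtype,
                      device=src.device)
    return out.index_add_(0, index, src)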
def forward(self, x, batch, ei):
    # batch is the batch index tensor
    # TODO: implement for separate source and target tensors?
    # That way we get rid of excess computation.
    src, dest = ei

    B = batch[-1] + 1  # number of graphs in the batch
    H = self.nheads
    Fh = self.Fqk // H
    Fhv = self.Fv // H

    scaling = float(Fh) ** -0.5
    q, k, v = self.proj(x).split([self.Fqk, self.Fqk, self.Fv], dim=-1)

    q = q * scaling
    q = q.reshape(-1, H, Fh)
    k = k.reshape(-1, H, Fh)
    v = v.reshape(-1, H, Fhv)

    # print(src, dest)
    qs, ks, vs = q[src], k[dest], v[dest]

    # dot product
    t0 = time.time()
    aw = qs.view(-1, H, 1, Fh) @ ks.view(-1, H, Fh, 1)
    t_mhsa_aw = time.time() - t0
    aw = aw.squeeze()

    # softmax reduction
    t_mhsa_scatter_softmax, aw = timer(scatter_softmax)(aw, src)

    out = aw.view([-1, H, 1]) * vs
    t_mhsa_scatter_sum, out = timer(scatter_sum)(out, src)
    out = out.reshape([-1, H * Fhv])

    log_time = {
        't_mhsa_aw': t_mhsa_aw,
        't_mhsa_scatter_softmax': t_mhsa_scatter_softmax,
        't_mhsa_scatter_sum': t_mhsa_scatter_sum
    }

    return out, log_time
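# `gnns.utils.get_ei_from` is not shown in this excerpt. From its call sites
# (it is built from (mbatch, xbatch), consumed here as (src, dest), and the
# scatter reductions over `src` must return one row per memory slot), `src`
# is assumed to index the memory slots, i.e. the first len(mbatch) rows of the
# concatenated [memory, x] tensor, and `dest` the nodes they attend to: every
# node of the same batch element. A hypothetical sketch of that construction
# (the actual function also returns a timing dict, omitted here):

import torch

def get_ei_from_sketch(mbatch, xbatch):
    batch_cat = torch.cat([mbatch, xbatch], 0)
    # connect each memory slot (query) to every node sharing its batch index
    same_batch = mbatch.view(-1, 1) == batch_cat.view(1, -1)  # [n_mem, n_all]
    src, dest = same_batch.nonzero(as_tuple=True)
    return torch.stack([src, dest], 0)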
def update_parameters(self):
    # Collect experiences
    t_collect, (exps, logs) = timer(self.collect_experiences)()
    logs['t_collect'] = t_collect
    '''
    exps is a DictList with the following keys ['obs', 'memory', 'mask',
    'action', 'value', 'reward', 'advantage', 'returnn', 'log_prob'] and
    ['collected_info', 'extra_predictions'] if we use aux_info.
    exps.obs is a DictList with the following keys ['image', 'instr'].
    exps.obs.image is a (n_procs * n_frames_per_proc) x image_size 4D tensor.
    exps.obs.instr is a (n_procs * n_frames_per_proc) x (max number of words
    in an instruction) 2D tensor.
    exps.memory is a (n_procs * n_frames_per_proc) x
    (memory_size = 2*image_embedding_size) 2D tensor.
    exps.mask is a (n_procs * n_frames_per_proc) x 1 2D tensor.
    If we use aux_info: exps.collected_info and exps.extra_predictions are
    DictLists whose keys are the added information. They are either
    (n_procs * n_frames_per_proc) 1D tensors or
    (n_procs * n_frames_per_proc) x k 2D tensors, where k is the number of
    classes for multiclass classification.
    '''
    t0_train = time.time()
    t_details_train_forward_model = {}
    t_train_backward = 0

    n = 0
    for _ in range(self.epochs):
        n = n + 1
        # Initialize log values
        log_entropies = []
        log_values = []
        log_policy_losses = []
        log_value_losses = []
        log_grad_norms = []
        log_losses = []
        '''
        For each epoch, we create int(total_frames / batch_size + 1) batches,
        each of size batch_size (except maybe the last one). Each batch is
        divided into sub-batches of size recurrence (frames are contiguous in
        a sub-batch), but the position of each sub-batch in a batch and the
        position of each batch in the whole list of frames is random thanks
        to self._get_batches_starting_indexes().
        '''
        for inds in self._get_batches_starting_indexes():
            # inds is a numpy array of indices that correspond to the
            # beginning of a sub-batch; there are as many inds as there
            # are batches.

            # Initialize batch values
            batch_entropy = 0
            batch_value = 0
            batch_policy_loss = 0
            batch_value_loss = 0
            batch_loss = 0

            # Initialize memory: extract first memories
            inds_mem = [
                item for sublist in [
                    list(
                        range(
                            self.acmodel.memory_dim[0] * i,
                            self.acmodel.memory_dim[0] * i +
                            self.acmodel.memory_dim[0])) for i in inds
                ] for item in sublist
            ]
            memory = exps.memory[inds_mem]
            all_obs_inds = exps.obs[1].image

            sb = DictList()
            for i in range(self.recurrence):
                # Extract scene-level quantities
                sb.action = exps.action[inds + i]
                sb.log_prob = exps.log_prob[inds + i]
                sb.advantage = exps.advantage[inds + i]
                sb.value = exps.value[inds + i]
                sb.returnn = exps.returnn[inds + i]

                m_batch = torch.IntTensor([
                    j + i for j in inds
                    for _ in range(self.acmodel.memory_dim[0])
                ])

                # Extract sub-batch of observations and observation batch indices
                sb.obs = torch.zeros((0, self.acmodel.image_dim))
                sb.obs_batch = torch.zeros(0).int()
                for j in inds + i:
                    idx_j = all_obs_inds == j
                    sb.obs = torch.cat([sb.obs, exps.obs[0].image[idx_j]],
                                       dim=0)
                    sb.obs_batch = torch.cat(
                        [sb.obs_batch, exps.obs[1].image[idx_j]], dim=0)
                # TODO: rename obs[0] and obs[1] into obs.obs and obs.obs_batch

                # Reshape mask
                sb.mask = exps.mask[list(
                    numpy.array(inds_mem) +
                    self.acmodel.memory_dim[0] * i)].flatten()

                # Compute loss
                model_results = self.acmodel(sb.obs,
                                             sb.mask.unsqueeze(1) * memory,
                                             sb.obs_batch, m_batch)
                dist = model_results['dist']
                value = model_results['value']
                memory = model_results['memory']
                extra_predictions = model_results['extra_predictions']

                entropy = dist.entropy().mean()
                t_details_train_forward_model = cumulate_value(
                    t_details_train_forward_model, model_results['log_time'])

                ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                surr1 = ratio * sb.advantage
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps,
                                    1.0 + self.clip_eps) * sb.advantage
                policy_loss = -torch.min(surr1, surr2).mean()

                value_clipped = sb.value + torch.clamp(
                    value - sb.value, -self.clip_eps, self.clip_eps)
                surr1 = (value - sb.returnn).pow(2)
                surr2 = (value_clipped - sb.returnn).pow(2)
                value_loss = torch.max(surr1, surr2).mean()

                loss = (policy_loss - self.entropy_coef * entropy +
                        self.value_loss_coef * value_loss)

                # Update batch values
                batch_entropy += entropy.item()
                batch_value += value.mean().item()
                batch_policy_loss += policy_loss.item()
                batch_value_loss += value_loss.item()
                batch_loss += loss

                # Update memories for next epoch
                if i < self.recurrence - 1:
                    exps.memory[list(
                        numpy.array(inds_mem) +
                        self.acmodel.memory_dim[0] *
                        (i + 1))] = memory.detach()

            # Update batch values
            batch_entropy /= self.recurrence
            batch_value /= self.recurrence
            batch_policy_loss /= self.recurrence
            batch_value_loss /= self.recurrence
            batch_loss /= self.recurrence

            # Update actor-critic
            t0_train_backward = time.time()
            self.optimizer.zero_grad()
            batch_loss.backward()
            grad_norm = sum(
                p.grad.data.norm(2)**2 for p in self.acmodel.parameters()
                if p.grad is not None)**0.5
            torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(),
                                           self.max_grad_norm)
            self.optimizer.step()
            t_train_backward += time.time() - t0_train_backward

            # Update log values
            log_entropies.append(batch_entropy)
            log_values.append(batch_value)
            log_policy_losses.append(batch_policy_loss)
            log_value_losses.append(batch_value_loss)
            log_grad_norms.append(grad_norm.item())
            log_losses.append(batch_loss.item())

    t_train = time.time() - t0_train

    # Log some values
    logs["entropy"] = numpy.mean(log_entropies)
    logs["value"] = numpy.mean(log_values)
    logs["policy_loss"] = numpy.mean(log_policy_losses)
    logs["value_loss"] = numpy.mean(log_value_losses)
    logs["grad_norm"] = numpy.mean(log_grad_norms)
    logs["loss"] = numpy.mean(log_losses)
    logs['t_collect'] = t_collect
    logs['t_train'] = t_train
    logs['t_details_train_forward_model'] = t_details_train_forward_model
    logs['t_backward'] = t_train_backward

    return logs
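# `cumulate_value` is used above (and again in collect_experiences below) to
# aggregate the per-call timing dicts returned by the model. A hedged sketch
# of what it is assumed to do: add each value of the new dict into the
# accumulator, recursing into the nested 'details_*' sub-dicts. This is an
# assumption based on the call sites, not the repository's implementation.

def cumulate_value_sketch(acc, new):
    for key, val in new.items():
        if isinstance(val, dict):
            acc[key] = cumulate_value_sketch(acc.get(key, {}), val)
        else:
            acc[key] = acc.get(key, 0.0) + val
    return acc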
def collect_experiences(self):
    """Collects rollouts and computes advantages.

    Runs several environments concurrently. The next actions are computed
    in a batch mode for all environments at the same time. The rollouts
    and advantages from all environments are concatenated together.

    Returns
    -------
    exps : DictList
        Contains actions, rewards, advantages etc. as attributes.
        Each attribute, e.g. `exps.reward`, has a shape
        (self.num_frames_per_proc * num_envs, ...). The k-th block of
        consecutive `self.num_frames_per_proc` frames contains data
        obtained from the k-th environment. Be careful not to mix data
        from different environments!
    logs : dict
        Useful stats about the training process, including the average
        reward, policy loss, value loss, etc.
    """
    t0 = time.time()
    t_forward_process = 0
    t_forward_step = 0
    t_details_forward_model = {}

    for i in range(self.num_frames_per_proc):
        # Do one agent-environment interaction
        tt_process, preprocessed_obs = timer(self.preprocess_obss)(
            self.obs, device=self.device)
        t_forward_process += tt_process
        obs_flat = preprocessed_obs.image[0]
        obs_batch = preprocessed_obs.image[1]

        with torch.no_grad():
            model_results = self.acmodel(
                obs_flat, self.mask.unsqueeze(1) * self.memory, obs_batch,
                self.m_batch)
            dist = model_results['dist']
            value = model_results['value'].flatten()
            memory = model_results['memory']
            extra_predictions = model_results['extra_predictions']

        t_details_forward_model = cumulate_value(t_details_forward_model,
                                                 model_results['log_time'])

        action = dist.sample()

        tt_step, (obs, reward, done, env_info) = timer(self.env.step)(
            action.cpu().numpy())
        t_forward_step += tt_step

        if self.aux_info:
            env_info = self.aux_info_collector.process(env_info)
            # env_info = self.process_aux_info(env_info)

        # Update experiences values
        self.obss[i] = self.obs
        self.obs = obs

        self.memories[i] = self.memory
        self.memory = memory

        self.masks[i] = self.mask
        done_as_int = torch.tensor(done, device=self.device,
                                   dtype=torch.float).unsqueeze(1)
        self.mask = 1 - done_as_int.expand(
            done_as_int.shape[0], self.acmodel.memory_size[0]).flatten()
        self.actions[i] = action
        self.values[i] = value
        if self.reshape_reward is not None:
            self.rewards[i] = torch.tensor([
                self.reshape_reward(obs_, action_, reward_, done_)
                for obs_, action_, reward_, done_ in zip(
                    obs, action, reward, done)
            ], device=self.device)
        else:
            self.rewards[i] = torch.tensor(reward, device=self.device)
        self.log_probs[i] = dist.log_prob(action)

        if self.aux_info:
            self.aux_info_collector.fill_dictionaries(i, env_info,
                                                      extra_predictions)

        # Update log values
        self.log_episode_return += torch.tensor(reward,
                                                device=self.device,
                                                dtype=torch.float)
        self.log_episode_reshaped_return += self.rewards[i]
        self.log_episode_num_frames += torch.ones(self.num_procs,
                                                  device=self.device)

        for i, done_ in enumerate(done):
            if done_:
                self.log_done_counter += 1
                self.log_return.append(self.log_episode_return[i].item())
                self.log_reshaped_return.append(
                    self.log_episode_reshaped_return[i].item())
                self.log_num_frames.append(
                    self.log_episode_num_frames[i].item())

        episode_mask = torch.tensor([
            self.mask[i * self.acmodel.memory_size[0]]
            for i in range(self.num_procs)
        ])
        self.log_episode_return *= episode_mask
        self.log_episode_reshaped_return *= episode_mask
        self.log_episode_num_frames *= episode_mask

    t_collect_forward = time.time() - t0

    # Add advantage and return to experiences
    t0 = time.time()
    preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
    with torch.no_grad():
        # TODO: add the obs_flat / obs_batch split to preprocess_obss?
        obs_flat = preprocessed_obs.image[0]
        obs_batch = preprocessed_obs.image[1]
        next_value = self.acmodel(obs_flat,
                                  self.mask.unsqueeze(1) * self.memory,
                                  obs_batch, self.m_batch)['value'].flatten()

    for i in reversed(range(self.num_frames_per_proc)):
        next_mask = torch.tensor([
            self.masks[i + 1][j * self.acmodel.memory_size[0]]
            for j in range(self.num_procs)
        ]) if i < self.num_frames_per_proc - 1 else torch.tensor([
            self.mask[j * self.acmodel.memory_size[0]]
            for j in range(self.num_procs)
        ])
        next_value = self.values[
            i + 1] if i < self.num_frames_per_proc - 1 else next_value
        next_advantage = self.advantages[
            i + 1] if i < self.num_frames_per_proc - 1 else 0

        delta = (self.rewards[i] + self.discount * next_value * next_mask -
                 self.values[i])
        self.advantages[i] = (delta + self.discount * self.gae_lambda *
                              next_advantage * next_mask)

    t_collect_backward = time.time() - t0

    # Flatten the data correctly, making sure that
    # each episode's data is a continuous chunk.
    t0 = time.time()
    exps = DictList()
    exps.obs = [
        self.obss[i][j] for j in range(self.num_procs)
        for i in range(self.num_frames_per_proc)
    ]

    # In the comments below, T is self.num_frames_per_proc, P is
    # self.num_procs, D is the dimensionality and M the number of memory slots.

    # T x (P * M) x D -> T x P x M x D -> P x T x M x D -> (P * T * M) x D
    exps.memory = self.memories.reshape(
        (self.num_frames_per_proc, self.num_procs,
         self.acmodel.memory_size[0],
         self.acmodel.memory_size[1])).transpose(0, 1).reshape(
             -1, *self.memories.shape[2:])

    # T x (P * M) -> T x P x M -> P x T x M -> (P * T * M) x 1
    exps.mask = self.masks.reshape(self.num_frames_per_proc, self.num_procs,
                                   self.acmodel.memory_size[0]).transpose(
                                       0, 1).reshape(-1).unsqueeze(1)

    # For all tensors below, T x P -> P x T -> P * T
    exps.action = self.actions.transpose(0, 1).reshape(-1)
    exps.value = self.values.transpose(0, 1).reshape(-1)
    exps.reward = self.rewards.transpose(0, 1).reshape(-1)
    exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
    exps.returnn = exps.value + exps.advantage
    exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)
    t_organize_exp = time.time() - t0

    if self.aux_info:
        exps = self.aux_info_collector.end_collection(exps)

    # Preprocess experiences
    exps.obs = self.preprocess_obss(exps.obs, device=self.device)

    # Log some values
    keep = max(self.log_done_counter, self.num_procs)

    log = {
        "return_per_episode": self.log_return[-keep:],
        "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
        "num_frames_per_episode": self.log_num_frames[-keep:],
        "num_frames": self.num_frames,
        "episodes_done": self.log_done_counter,
        "t_collect_forward": t_collect_forward,
        "t_details_forward_model": t_details_forward_model,
        "t_forward_process": t_forward_process,
        "t_forward_step": t_forward_step,
        "t_collect_backward": t_collect_backward,
        "t_collect_organize": t_organize_exp
    }

    self.log_done_counter = 0
    self.log_return = self.log_return[-self.num_procs:]
    self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
    self.log_num_frames = self.log_num_frames[-self.num_procs:]

    return exps, log
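# The backward loop above implements generalized advantage estimation (GAE).
# For reference, a self-contained sketch of the same recursion for a single
# environment, using the same masking convention (next_masks[t] is 0 when the
# episode ended at step t); the names here are illustrative, not part of the
# repository:

import torch

def gae_sketch(rewards, values, next_value, next_masks, discount, gae_lambda):
    # rewards, values, next_masks: [T] tensors; next_value: bootstrap value
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    next_advantage = 0.0
    for t in reversed(range(T)):
        nv = values[t + 1] if t < T - 1 else next_value
        delta = rewards[t] + discount * nv * next_masks[t] - values[t]
        advantages[t] = delta + \
            discount * gae_lambda * next_advantage * next_masks[t]
        next_advantage = advantages[t]
    return advantages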