class PpoPlayerContinuous(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.shape[0]
        self.actions_low = torch.from_numpy(self.action_space.low.copy()).float().to(self.device)
        self.actions_high = torch.from_numpy(self.action_space.high.copy()).float().to(self.device)
        self.mask = [False]
        self.normalize_input = self.config['normalize_input']

        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents
        }
        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()

    def get_action(self, obs, is_determenistic=False):
        if self.has_batch_dimension == False:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'rnn_states': self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        mu = res_dict['mus']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if is_determenistic:
            current_action = mu
        else:
            current_action = action
        current_action = torch.squeeze(current_action.detach())
        return rescale_actions(self.actions_low, self.actions_high, torch.clamp(current_action, -1.0, 1.0))

    def restore(self, fn):
        checkpoint = torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

    def reset(self):
        self.init_rnn()
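
# --- Illustrative sketch (not part of the original file) ---
# PpoPlayerContinuous.get_action clamps the policy output to [-1, 1] and then calls
# rescale_actions to map it onto the environment's action bounds. The helper below is a
# hypothetical stand-in that shows the linear map rescale_actions is expected to perform;
# the real implementation lives elsewhere in this repository.
def _example_rescale_actions(low, high, clamped_action):
    # Map a value in [-1, 1] to [low, high]: low + (x + 1) * (high - low) / 2,
    # written here in scale/offset form.
    scale = (high - low) * 0.5
    offset = (high + low) * 0.5
    return clamped_action * scale + offset
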
class PpoPlayerDiscrete(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        if type(self.action_space) is gym.spaces.Discrete:
            self.actions_num = self.action_space.n
            self.is_multi_discrete = False
        if type(self.action_space) is gym.spaces.Tuple:
            self.actions_num = [action.n for action in self.action_space]
            self.is_multi_discrete = True
        self.mask = [False]
        self.normalize_input = self.config['normalize_input']

        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents,
            'value_size': self.value_size
        }
        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()

    def get_masked_action(self, obs, action_masks, is_determenistic=True):
        if self.has_batch_dimension == False:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        action_masks = torch.Tensor(action_masks).to(self.device)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'action_masks': action_masks,
            'rnn_states': self.states
        }
        self.model.eval()
        with torch.no_grad():
            res_dict = self.model(input_dict)
        logits = res_dict['logits']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if self.is_multi_discrete:
            if is_determenistic:
                action = [torch.argmax(logit.detach(), axis=-1).squeeze() for logit in logits]
                return torch.stack(action, dim=-1)
            else:
                return action.squeeze().detach()
        else:
            if is_determenistic:
                return torch.argmax(logits.detach(), axis=-1).squeeze()
            else:
                return action.squeeze().detach()

    def get_action(self, obs, is_determenistic=False):
        if self.has_batch_dimension == False:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'rnn_states': self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        logits = res_dict['logits']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if self.is_multi_discrete:
            if is_determenistic:
                action = [torch.argmax(logit.detach(), axis=1).squeeze() for logit in logits]
                return torch.stack(action, dim=-1)
            else:
                return action.squeeze().detach()
        else:
            if is_determenistic:
                return torch.argmax(logits.detach(), axis=-1).squeeze()
            else:
                return action.squeeze().detach()

    def restore(self, fn):
        checkpoint = torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

    def reset(self):
        self.init_rnn()
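
# --- Illustrative sketch (not part of the original file) ---
# PpoPlayerDiscrete.get_action either takes the argmax of the logits (deterministic
# evaluation) or returns the action already sampled by the model. The hypothetical helper
# below shows the same two behaviours for a single Discrete action space, using
# torch.distributions.Categorical for the stochastic case.
def _example_select_discrete_action(logits, is_determenistic):
    import torch  # local import keeps this sketch self-contained
    if is_determenistic:
        # Greedy: pick the highest-scoring action for each batch element.
        return torch.argmax(logits, dim=-1)
    # Stochastic: sample from the categorical distribution defined by the logits.
    dist = torch.distributions.Categorical(logits=logits)
    return dist.sample()
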
class A2CBase:
    def __init__(self, base_name, config):
        self.config = config
        self.env_config = config.get('env_config', {})
        self.num_actors = config['num_actors']
        self.env_name = config['env_name']
        self.env_info = config.get('env_info')
        if self.env_info is None:
            self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
            self.env_info = self.vec_env.get_env_info()

        self.ppo_device = config.get('device', 'cuda:0')
        print('Env info:')
        print(self.env_info)
        self.value_size = self.env_info.get('value_size', 1)
        self.observation_space = self.env_info['observation_space']
        self.weight_decay = config.get('weight_decay', 0.0)
        self.use_action_masks = config.get('use_action_masks', False)
        self.is_train = config.get('is_train', True)

        self.central_value_config = self.config.get('central_value_config', None)
        self.has_central_value = self.central_value_config is not None
        if self.has_central_value:
            self.state_space = self.env_info.get('state_space', None)
            self.state_shape = None
            if self.state_space.shape is not None:
                self.state_shape = self.state_space.shape

        self.self_play_config = self.config.get('self_play_config', None)
        self.has_self_play_config = self.self_play_config is not None

        self.self_play = config.get('self_play', False)
        self.save_freq = config.get('save_frequency', 0)
        self.save_best_after = config.get('save_best_after', 100)
        self.print_stats = config.get('print_stats', True)
        self.rnn_states = None
        self.name = base_name

        self.ppo = config['ppo']
        self.max_epochs = self.config.get('max_epochs', 1e6)

        self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
        self.linear_lr = config['lr_schedule'] == 'linear'
        self.schedule_type = config.get('schedule_type', 'legacy')
        if self.is_adaptive_lr:
            self.lr_threshold = config['lr_threshold']
            self.scheduler = schedulers.AdaptiveScheduler(self.lr_threshold)
        elif self.linear_lr:
            self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']),
                                                        max_steps=self.max_epochs,
                                                        apply_to_entropy=config.get('schedule_entropy', False),
                                                        start_entropy_coef=config.get('entropy_coef'))
        else:
            self.scheduler = schedulers.IdentityScheduler()

        self.e_clip = config['e_clip']
        self.clip_value = config['clip_value']
        self.network = config['network']
        self.rewards_shaper = config['reward_shaper']
        self.num_agents = self.env_info.get('agents', 1)
        self.steps_num = config['steps_num']
        self.seq_len = self.config.get('seq_length', 4)
        self.normalize_advantage = config['normalize_advantage']
        self.normalize_input = self.config['normalize_input']
        self.normalize_value = self.config.get('normalize_value', False)
        self.obs_shape = self.observation_space.shape

        self.critic_coef = config['critic_coef']
        self.grad_norm = config['grad_norm']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.games_to_track = self.config.get('games_to_track', 100)
        self.game_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
        self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
        self.obs = None
        self.games_num = self.config['minibatch_size'] // self.seq_len  # it is used only for current rnn implementation

        self.batch_size = self.steps_num * self.num_actors * self.num_agents
        self.batch_size_envs = self.steps_num * self.num_actors
        self.minibatch_size = self.config['minibatch_size']
        self.mini_epochs_num = self.config['mini_epochs']
        self.num_minibatches = self.batch_size // self.minibatch_size
        assert(self.batch_size % self.minibatch_size == 0)

        self.last_lr = self.config['learning_rate']
        self.frame = 0
        self.update_time = 0
        self.last_mean_rewards = -100500
        self.play_time = 0
        self.epoch_num = 0

        self.entropy_coef = self.config['entropy_coef']
        self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S"))

        if self.normalize_value:
            self.value_mean_std = RunningMeanStd((1,)).to(self.ppo_device)

        self.is_tensor_obses = False
        self.last_rnn_indices = None
        self.last_state_indices = None

        # self_play
        if self.has_self_play_config:
            print('Initializing SelfPlay Manager')
            self.self_play_manager = SelfPlayManager(self.self_play_config, self.writer)

        # features
        self.algo_observer = config['features']['observer']

    def set_eval(self):
        self.model.eval()
        if self.normalize_input:
            self.running_mean_std.eval()
        if self.normalize_value:
            self.value_mean_std.eval()

    def set_train(self):
        self.model.train()
        if self.normalize_input:
            self.running_mean_std.train()
        if self.normalize_value:
            self.value_mean_std.train()

    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_action_values(self, obs):
        processed_obs = self._preproc_obs(obs['obs'])
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': processed_obs,
            'rnn_states': self.rnn_states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
            if self.has_central_value:
                states = obs['states']
                input_dict = {
                    'is_train': False,
                    'states': states,
                    #'actions' : res_dict['action'],
                    #'rnn_states' : self.rnn_states
                }
                value = self.get_central_value(input_dict)
                res_dict['value'] = value
        if self.normalize_value:
            res_dict['value'] = self.value_mean_std(res_dict['value'], True)
        return res_dict

    def get_values(self, obs):
        with torch.no_grad():
            if self.has_central_value:
                states = obs['states']
                self.central_value_net.eval()
                input_dict = {
                    'is_train': False,
                    'states': states,
                    'actions': None,
                    'is_done': self.dones,
                }
                value = self.get_central_value(input_dict)
            else:
                self.model.eval()
                processed_obs = self._preproc_obs(obs['obs'])
                input_dict = {
                    'is_train': False,
                    'prev_actions': None,
                    'obs': processed_obs,
                    'rnn_states': self.rnn_states
                }
                result = self.model(input_dict)
                value = result['value']

            if self.normalize_value:
                value = self.value_mean_std(value, True)
            return value

    def reset_envs(self):
        self.obs = self.env_reset()

    def init_tensors(self):
        if self.observation_space.dtype == np.uint8:
            torch_dtype = torch.uint8
        else:
            torch_dtype = torch.float32
        batch_size = self.num_agents * self.num_actors

        val_shape = (self.steps_num, batch_size, self.value_size)
        current_rewards_shape = (batch_size, self.value_size)
        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
        self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.ppo_device)
        self.mb_obs = torch.zeros((self.steps_num, batch_size) + self.obs_shape, dtype=torch_dtype, device=self.ppo_device)
        if self.has_central_value:
            self.mb_vobs = torch.zeros((self.steps_num, self.num_actors) + self.state_shape, dtype=torch_dtype, device=self.ppo_device)

        self.mb_rewards = torch.zeros(val_shape, dtype=torch.float32, device=self.ppo_device)
        self.mb_values = torch.zeros(val_shape, dtype=torch.float32, device=self.ppo_device)
        self.mb_dones = torch.zeros((self.steps_num, batch_size), dtype=torch.uint8, device=self.ppo_device)
        self.mb_neglogpacs = torch.zeros((self.steps_num, batch_size), dtype=torch.float32, device=self.ppo_device)

        if self.is_rnn:
            self.rnn_states = self.model.get_default_rnn_state()
            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
            batch_size = self.num_agents * self.num_actors
            num_seqs = self.steps_num * batch_size // self.seq_len
            assert((self.steps_num * batch_size // self.num_minibatches) % self.seq_len == 0)
            self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]

    def init_rnn_from_model(self, model):
        self.is_rnn = self.model.is_rnn()

    def init_rnn_step(self, batch_size, mb_rnn_states):
        # Allocates per-rollout masks and index helpers used to scatter step data into
        # sequence-ordered RNN buffers.
        mb_rnn_states = self.mb_rnn_states
        mb_rnn_masks = torch.zeros(self.steps_num * batch_size, dtype=torch.float32, device=self.ppo_device)
        steps_mask = torch.arange(0, batch_size * self.steps_num, self.steps_num, dtype=torch.long, device=self.ppo_device)
        play_mask = torch.arange(0, batch_size, 1, dtype=torch.long, device=self.ppo_device)
        steps_state = torch.arange(0, batch_size * self.steps_num // self.seq_len, self.steps_num // self.seq_len, dtype=torch.long, device=self.ppo_device)
        indices = torch.zeros((batch_size), dtype=torch.long, device=self.ppo_device)
        return mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states

    def process_rnn_indices(self, mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states):
        # Marks the current step as valid in mb_rnn_masks and, at the start of each
        # sequence, snapshots the RNN hidden states into mb_rnn_states. Returns
        # (seq_indices, full_tensor); full_tensor is True once any environment has
        # filled all of its steps_num slots.
        seq_indices = None
        if indices.max().item() >= self.steps_num:
            return seq_indices, True

        mb_rnn_masks[indices + steps_mask] = 1
        seq_indices = indices % self.seq_len
        state_indices = (seq_indices == 0).nonzero(as_tuple=False)
        state_pos = indices // self.seq_len
        rnn_indices = state_pos[state_indices] + steps_state[state_indices]

        for s, mb_s in zip(self.rnn_states, mb_rnn_states):
            mb_s[:, rnn_indices, :] = s[:, state_indices, :]

        self.last_rnn_indices = rnn_indices
        self.last_state_indices = state_indices
        return seq_indices, False

    def process_rnn_dones(self, all_done_indices, indices, seq_indices):
        # For environments that finished an episode, fast-forward their write index to
        # the end of the current sequence and zero their hidden states; then advance
        # every index by one step.
        if len(all_done_indices) > 0:
            shifts = self.seq_len - 1 - seq_indices[all_done_indices]
            indices[all_done_indices] += shifts
            for s in self.rnn_states:
                s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0
        indices += 1

    def cast_obs(self, obs):
        if isinstance(obs, torch.Tensor):
            self.is_tensor_obses = True
        elif isinstance(obs, np.ndarray):
            assert(self.observation_space.dtype != np.int8)
            if self.observation_space.dtype == np.uint8:
                obs = torch.ByteTensor(obs).to(self.ppo_device)
            else:
                obs = torch.FloatTensor(obs).to(self.ppo_device)
        return obs

    def obs_to_tensors(self, obs):
        if isinstance(obs, dict):
            upd_obs = {}
            for key, value in obs.items():
                upd_obs[key] = self.cast_obs(value)
        else:
            upd_obs = {'obs': self.cast_obs(obs)}
        return upd_obs

    def preprocess_actions(self, actions):
        if not self.is_tensor_obses:
            actions = actions.cpu().numpy()
        return actions

    def env_step(self, actions):
        actions = self.preprocess_actions(actions)
        obs, rewards, dones, infos = self.vec_env.step(actions)

        if self.is_tensor_obses:
            if self.value_size == 1:
                rewards = rewards.unsqueeze(1)
            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos
        else:
            if self.value_size == 1:
                rewards = np.expand_dims(rewards, axis=1)
            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos

    def env_reset(self):
        obs = self.vec_env.reset()
        obs = self.obs_to_tensors(obs)
        return obs

    def discount_values(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards):
        # Generalized Advantage Estimation (GAE); self.tau plays the role of lambda.
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.steps_num)):
            if t == self.steps_num - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t + 1]
                nextvalues = mb_extrinsic_values[t + 1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.tau * nextnonterminal * lastgaelam
        return mb_advs

    def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_masks):
        # Same as discount_values, but zeroes advantages on padded (masked-out) RNN steps.
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.steps_num)):
            if t == self.steps_num - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t + 1]
                nextvalues = mb_extrinsic_values[t + 1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * mb_masks[t].unsqueeze(1)
        return mb_advs

    def clear_stats(self):
        batch_size = self.num_agents * self.num_actors
        self.game_rewards.clear()
        self.game_lengths.clear()
        self.last_mean_rewards = -100500
        self.algo_observer.after_clear_stats()

    def update_epoch(self):
        pass

    def train(self):
        pass

    def prepare_dataset(self, batch_dict):
        pass

    def train_epoch(self):
        pass

    def train_actor_critic(self, obs_dict, opt_step=True):
        pass

    def calc_gradients(self, opt_step):
        pass

    def get_central_value(self, obs_dict):
        return self.central_value_net.get_value(obs_dict)

    def train_central_value(self):
        return self.central_value_net.train_net()

    def get_full_state_weights(self):
        state = self.get_weights()
        state['epoch'] = self.epoch_num
        state['optimizer'] = self.optimizer.state_dict()
        if self.has_central_value:
            state['assymetric_vf_nets'] = self.central_value_net.state_dict()
        return state

    def set_full_state_weights(self, weights):
        self.set_weights(weights)
        self.epoch_num = weights['epoch']
        if self.has_central_value:
            self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
        self.optimizer.load_state_dict(weights['optimizer'])

    def get_weights(self):
        state = {'model': self.model.state_dict()}
        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.normalize_value:
            state['reward_mean_std'] = self.value_mean_std.state_dict()
        return state

    def get_stats_weights(self):
        state = {}
        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.normalize_value:
            state['reward_mean_std'] = self.value_mean_std.state_dict()
        if self.has_central_value:
            state['assymetric_vf_mean_std'] = self.central_value_net.get_stats_weights()
        return state

    def set_stats_weights(self, weights):
        if self.normalize_input:
            self.running_mean_std.load_state_dict(weights['running_mean_std'])
        if self.normalize_value:
            self.value_mean_std.load_state_dict(weights['reward_mean_std'])
        if self.has_central_value:
            self.central_value_net.set_stats_weights(weights['assymetric_vf_mean_std'])

    def set_weights(self, weights):
        self.model.load_state_dict(weights['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(weights['running_mean_std'])
        if self.normalize_value:
            self.value_mean_std.load_state_dict(weights['reward_mean_std'])

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == torch.uint8:
            obs_batch = obs_batch.float() / 255.0
        #if len(obs_batch.size()) == 3:
        #    obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def play_steps(self):
        mb_rnn_states = []
        epinfos = []
        mb_obs = self.mb_obs
        mb_rewards = self.mb_rewards
        mb_values = self.mb_values
        mb_dones = self.mb_dones

        tensors_dict = self.tensors_dict
        update_list = self.update_list
        update_dict = self.update_dict

        if self.has_central_value:
            mb_vobs = self.mb_vobs

        batch_size = self.num_agents * self.num_actors
        mb_rnn_masks = None

        for n in range(self.steps_num):
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            mb_obs[n, :] = self.obs['obs']
            mb_dones[n, :] = self.dones
            for k in update_list:
                tensors_dict[k][n, :] = res_dict[k]

            if self.has_central_value:
                mb_vobs[n, :] = self.obs['states']

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])
            shaped_rewards = self.rewards_shaper(rewards)
            mb_rewards[n, :] = shaped_rewards

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]

            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()
            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                val_dict = self.get_masked_action_values(self.obs, masks)
            else:
                val_dict = self.get_action_values(self.obs)
            last_values = val_dict['value']
        else:
            last_values = self.get_values(self.obs)

        mb_extrinsic_values = mb_values
        last_extrinsic_values = last_values
        fdones = self.dones.float()
        mb_fdones = mb_dones.float()

        mb_advs = self.discount_values(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards)
        mb_returns = mb_advs + mb_extrinsic_values

        batch_dict = {
            'obs': mb_obs,
            'returns': mb_returns,
            'dones': mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]

        if self.has_central_value:
            batch_dict['states'] = mb_vobs

        batch_dict = {k: swap_and_flatten01(v) for k, v in batch_dict.items()}
        return batch_dict

    def play_steps_rnn(self):
        mb_rnn_states = []
        epinfos = []
        mb_obs = self.mb_obs
        mb_values = self.mb_values.fill_(0)
        mb_rewards = self.mb_rewards.fill_(0)
        mb_dones = self.mb_dones.fill_(1)

        tensors_dict = self.tensors_dict
        update_list = self.update_list
        update_dict = self.update_dict

        if self.has_central_value:
            mb_vobs = self.mb_vobs

        batch_size = self.num_agents * self.num_actors
        mb_rnn_masks = None
        mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states = self.init_rnn_step(batch_size, mb_rnn_states)

        for n in range(self.steps_num):
            seq_indices, full_tensor = self.process_rnn_indices(mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states)
            if full_tensor:
                break
            if self.has_central_value:
                self.central_value_net.pre_step_rnn(self.last_rnn_indices, self.last_state_indices)

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            self.rnn_states = res_dict['rnn_state']
            mb_dones[indices, play_mask] = self.dones.byte()
            mb_obs[indices, play_mask] = self.obs['obs']

            for k in update_list:
                tensors_dict[k][indices, play_mask] = res_dict[k]

            if self.has_central_value:
                mb_vobs[indices[::self.num_agents], play_mask[::self.num_agents] // self.num_agents] = self.obs['states']

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])
            shaped_rewards = self.rewards_shaper(rewards)
            mb_rewards[indices, play_mask] = shaped_rewards

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]

            self.process_rnn_dones(all_done_indices, indices, seq_indices)

            if self.has_central_value:
                self.central_value_net.post_step_rnn(all_done_indices)

            self.algo_observer.process_infos(infos, done_indices)

            fdones = self.dones.float()
            not_dones = 1.0 - self.dones.float()

            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                val_dict = self.get_masked_action_values(self.obs, masks)
            else:
                val_dict = self.get_action_values(self.obs)
            last_values = val_dict['value']
        else:
            last_values = self.get_values(self.obs)

        mb_extrinsic_values = mb_values
        last_extrinsic_values = last_values

        fdones = self.dones.float()
        mb_fdones = mb_dones.float()

        non_finished = (indices != self.steps_num).nonzero(as_tuple=False)
        ind_to_fill = indices[non_finished]
        mb_fdones[ind_to_fill, non_finished] = fdones[non_finished]
        mb_extrinsic_values[ind_to_fill, non_finished] = last_extrinsic_values[non_finished]
        fdones[non_finished] = 1.0
        last_extrinsic_values[non_finished] = 0

        mb_advs = self.discount_values_masks(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_rnn_masks.view(-1, self.steps_num).transpose(0, 1))
        mb_returns = mb_advs + mb_extrinsic_values

        batch_dict = {
            'obs': mb_obs,
            'returns': mb_returns,
            'dones': mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]

        if self.has_central_value:
            batch_dict['states'] = mb_vobs

        batch_dict = {k: swap_and_flatten01(v) for k, v in batch_dict.items()}
        batch_dict['rnn_states'] = mb_rnn_states
        batch_dict['rnn_masks'] = mb_rnn_masks
        return batch_dict
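
# --- Illustrative sketch (not part of the original file) ---
# A2CBase.discount_values implements GAE(lambda) with self.tau playing the role of lambda:
# delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t) and
# A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}.
# The standalone function below is a minimal, hypothetical restatement of that recursion
# for 1-D tensors of shape (steps,); the class method operates on batched
# (steps, batch, value_size) tensors instead.
def _example_gae(rewards, values, dones, last_value, last_done, gamma=0.99, lam=0.95):
    import torch  # local import keeps this sketch self-contained
    steps = rewards.shape[0]
    advs = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(steps)):
        if t == steps - 1:
            nextnonterminal = 1.0 - last_done
            nextvalue = last_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalue = values[t + 1]
        delta = rewards[t] + gamma * nextvalue * nextnonterminal - values[t]
        advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
    returns = advs + values  # the same quantity play_steps stores as 'returns'
    return advs, returns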