Example 1
class PpoPlayerContinuous(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.shape[0]
        self.actions_low = torch.from_numpy(
            self.action_space.low.copy()).float().to(self.device)
        self.actions_high = torch.from_numpy(
            self.action_space.high.copy()).float().to(self.device)
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']
        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents
        }
        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()

    def get_action(self, obs, is_determenistic=False):
        if not self.has_batch_dimension:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'rnn_states': self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        mu = res_dict['mus']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if is_determenistic:
            current_action = mu
        else:
            current_action = action
        current_action = torch.squeeze(current_action.detach())
        return rescale_actions(self.actions_low, self.actions_high,
                               torch.clamp(current_action, -1.0, 1.0))

    def restore(self, fn):
        checkpoint = torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])

    def reset(self):
        self.init_rnn()
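
Note: rescale_actions is imported from the surrounding library and is not shown in this example. A minimal sketch of what such a helper typically does, assuming a plain linear map from the clamped [-1, 1] range onto [actions_low, actions_high] (the body below is an illustration, not the library's implementation):

import torch

def rescale_actions(low, high, action):
    # Map each action component from [-1, 1] onto [low, high] linearly.
    scale = (high - low) / 2.0
    center = (high + low) / 2.0
    return action * scale + center

For example, with low = -2.0 and high = 2.0, a clamped action of 0.5 maps to 1.0.
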
Example 2
class PpoPlayerDiscrete(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)

        self.network = config['network']
        if type(self.action_space) is gym.spaces.Discrete:
            self.actions_num = self.action_space.n
            self.is_multi_discrete = False
        if type(self.action_space) is gym.spaces.Tuple:
            self.actions_num = [action.n for action in self.action_space]
            self.is_multi_discrete = True
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']

        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents,
            'value_size': self.value_size
        }

        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()

    def get_masked_action(self, obs, action_masks, is_determenistic=True):
        if not self.has_batch_dimension:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        action_masks = torch.Tensor(action_masks).to(self.device)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'action_masks': action_masks,
            'rnn_states': self.states
        }
        self.model.eval()

        with torch.no_grad():
            res_dict = self.model(input_dict)
        logits = res_dict['logits']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if self.is_multi_discrete:
            if is_determenistic:
                action = [
                    torch.argmax(logit.detach(), axis=-1).squeeze()
                    for logit in logits
                ]
                return torch.stack(action, dim=-1)
            else:
                return action.squeeze().detach()
        else:
            if is_determenistic:
                return torch.argmax(logits.detach(), axis=-1).squeeze()
            else:
                return action.squeeze().detach()

    def get_action(self, obs, is_determenistic=False):
        if not self.has_batch_dimension:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'rnn_states': self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        logits = res_dict['logits']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if self.is_multi_discrete:
            if is_determenistic:
                action = [
                    torch.argmax(logit.detach(), axis=-1).squeeze()
                    for logit in logits
                ]
                return torch.stack(action, dim=-1)
            else:
                return action.squeeze().detach()
        else:
            if is_determenistic:
                return torch.argmax(logits.detach(), axis=-1).squeeze()
            else:
                return action.squeeze().detach()

    def restore(self, fn):
        checkpoint = torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])

    def reset(self):
        self.init_rnn()
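
The action_masks tensor passed to get_masked_action is consumed inside the model's forward pass, which is not part of this example. A common way such a mask is applied, sketched under the assumption that invalid actions are marked with 0 and must be excluded from both sampling and argmax (apply_action_mask is a hypothetical helper, not library code):

import torch

def apply_action_mask(logits, action_mask):
    # Push the logits of invalid actions to a very large negative value so
    # that neither a categorical sample nor torch.argmax can select them.
    return torch.where(action_mask.bool(), logits,
                       torch.full_like(logits, -1e8))

For example, with logits [2.0, 5.0, 1.0] and mask [1, 0, 1], argmax over the masked logits returns index 0 rather than index 1.
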
Example 3
class A2CBase:
    def __init__(self, base_name, config):
        self.config = config
        self.env_config = config.get('env_config', {})
        self.num_actors = config['num_actors']
        self.env_name = config['env_name']

        self.env_info = config.get('env_info')
        if self.env_info is None:
            self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
            self.env_info = self.vec_env.get_env_info()

        self.ppo_device = config.get('device', 'cuda:0')
        print('Env info:')
        print(self.env_info)
        self.value_size = self.env_info.get('value_size',1)
        self.observation_space = self.env_info['observation_space']
        self.weight_decay = config.get('weight_decay', 0.0)
        self.use_action_masks = config.get('use_action_masks', False)
        self.is_train = config.get('is_train', True)

        self.central_value_config = self.config.get('central_value_config', None)
        self.has_central_value = self.central_value_config is not None

        if self.has_central_value:
            self.state_space = self.env_info.get('state_space', None)
            self.state_shape = None
            if self.state_space.shape is not None:
                self.state_shape = self.state_space.shape

        self.self_play_config = self.config.get('self_play_config', None)
        self.has_self_play_config = self.self_play_config is not None

        self.self_play = config.get('self_play', False)
        self.save_freq = config.get('save_frequency', 0)
        self.save_best_after = config.get('save_best_after', 100)
        self.print_stats = config.get('print_stats', True)
        self.rnn_states = None
        self.name = base_name

        self.ppo = config['ppo']
        self.max_epochs = self.config.get('max_epochs', 1e6)

        self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
        self.linear_lr = config['lr_schedule'] == 'linear'
        self.schedule_type = config.get('schedule_type', 'legacy')
        if self.is_adaptive_lr:
            self.lr_threshold = config['lr_threshold']
            self.scheduler = schedulers.AdaptiveScheduler(self.lr_threshold)
        elif self.linear_lr:
            self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']), 
                max_steps=self.max_epochs, 
                apply_to_entropy=config.get('schedule_entropy', False),
                start_entropy_coef=config.get('entropy_coef'))
        else:
            self.scheduler = schedulers.IdentityScheduler()

        self.e_clip = config['e_clip']
        self.clip_value = config['clip_value']
        self.network = config['network']
        self.rewards_shaper = config['reward_shaper']
        self.num_agents = self.env_info.get('agents', 1)
        self.steps_num = config['steps_num']
        self.seq_len = self.config.get('seq_length', 4)
        self.normalize_advantage = config['normalize_advantage']
        self.normalize_input = self.config['normalize_input']
        self.normalize_value = self.config.get('normalize_value', False)

        self.obs_shape = self.observation_space.shape
 
        self.critic_coef = config['critic_coef']
        self.grad_norm = config['grad_norm']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.games_to_track = self.config.get('games_to_track', 100)
        self.game_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
        self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
        self.obs = None
        self.games_num = self.config['minibatch_size'] // self.seq_len  # used only by the current RNN implementation
        self.batch_size = self.steps_num * self.num_actors * self.num_agents
        self.batch_size_envs = self.steps_num * self.num_actors
        self.minibatch_size = self.config['minibatch_size']
        self.mini_epochs_num = self.config['mini_epochs']
        self.num_minibatches = self.batch_size // self.minibatch_size
        assert(self.batch_size % self.minibatch_size == 0)

        self.last_lr = self.config['learning_rate']
        self.frame = 0
        self.update_time = 0
        self.last_mean_rewards = -100500
        self.play_time = 0
        self.epoch_num = 0
        
        self.entropy_coef = self.config['entropy_coef']
        self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S"))

        if self.normalize_value:
            self.value_mean_std = RunningMeanStd((1,)).to(self.ppo_device)

        self.is_tensor_obses = False

        self.last_rnn_indices = None
        self.last_state_indices = None

        #self_play
        if self.has_self_play_config:
            print('Initializing SelfPlay Manager')
            self.self_play_manager = SelfPlayManager(self.self_play_config, self.writer)
        
        # features
        self.algo_observer = config['features']['observer']

    def set_eval(self):
        self.model.eval()
        if self.normalize_input:
            self.running_mean_std.eval()
        if self.normalize_value:
            self.value_mean_std.eval()

    def set_train(self):
        self.model.train()
        if self.normalize_input:
            self.running_mean_std.train()
        if self.normalize_value:
            self.value_mean_std.train()

    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_action_values(self, obs):
        processed_obs = self._preproc_obs(obs['obs'])
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None, 
            'obs' : processed_obs,
            'rnn_states' : self.rnn_states
        }

        with torch.no_grad():
            res_dict = self.model(input_dict)
            if self.has_central_value:
                states = obs['states']
                input_dict = {
                    'is_train': False,
                    'states' : states,
                    #'actions' : res_dict['action'],
                    #'rnn_states' : self.rnn_states
                }
                value = self.get_central_value(input_dict)
                res_dict['value'] = value
        if self.normalize_value:
            res_dict['value'] = self.value_mean_std(res_dict['value'], True)
        return res_dict

    def get_values(self, obs):
        with torch.no_grad():
            if self.has_central_value:
                states = obs['states']
                self.central_value_net.eval()
                input_dict = {
                    'is_train': False,
                    'states' : states,
                    'actions' : None,
                    'is_done': self.dones,
                }
                value = self.get_central_value(input_dict)
            else:
                self.model.eval()
                processed_obs = self._preproc_obs(obs['obs'])
                input_dict = {
                    'is_train': False,
                    'prev_actions': None, 
                    'obs' : processed_obs,
                    'rnn_states' : self.rnn_states
                }
                
                result = self.model(input_dict)
                value = result['value']

            if self.normalize_value:
                value = self.value_mean_std(value, True)
            return value

    def reset_envs(self):
        self.obs = self.env_reset()

    def init_tensors(self):
        if self.observation_space.dtype == np.uint8:
            torch_dtype = torch.uint8
        else:
            torch_dtype = torch.float32
        batch_size = self.num_agents * self.num_actors
 
        val_shape = (self.steps_num, batch_size, self.value_size)
        current_rewards_shape = (batch_size, self.value_size)
        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
        self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.ppo_device)
        self.mb_obs = torch.zeros((self.steps_num, batch_size) + self.obs_shape, dtype=torch_dtype, device=self.ppo_device)

        if self.has_central_value:
            self.mb_vobs = torch.zeros((self.steps_num, self.num_actors) + self.state_shape, dtype=torch_dtype, device=self.ppo_device)

        self.mb_rewards = torch.zeros(val_shape, dtype = torch.float32, device=self.ppo_device)
        self.mb_values = torch.zeros(val_shape, dtype = torch.float32, device=self.ppo_device)
        self.mb_dones = torch.zeros((self.steps_num, batch_size), dtype = torch.uint8, device=self.ppo_device)
        self.mb_neglogpacs = torch.zeros((self.steps_num, batch_size), dtype = torch.float32, device=self.ppo_device)

        if self.is_rnn:
            self.rnn_states = self.model.get_default_rnn_state()
            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]

            batch_size = self.num_agents * self.num_actors
            num_seqs = self.steps_num * batch_size // self.seq_len
            assert((self.steps_num * batch_size // self.num_minibatches) % self.seq_len == 0)
            self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

    def init_rnn_from_model(self, model):
        self.is_rnn = self.model.is_rnn()

    def init_rnn_step(self, batch_size, mb_rnn_states):
        mb_rnn_states = self.mb_rnn_states
        mb_rnn_masks = torch.zeros(self.steps_num*batch_size, dtype = torch.float32, device=self.ppo_device)
        steps_mask = torch.arange(0, batch_size * self.steps_num, self.steps_num, dtype=torch.long, device=self.ppo_device)
        play_mask = torch.arange(0, batch_size, 1, dtype=torch.long, device=self.ppo_device)
        steps_state = torch.arange(0, batch_size * self.steps_num//self.seq_len, self.steps_num//self.seq_len, dtype=torch.long, device=self.ppo_device)
        indices = torch.zeros((batch_size), dtype = torch.long, device=self.ppo_device)
        return mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states

    def process_rnn_indices(self, mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states):
        seq_indices = None
        if indices.max().item() >= self.steps_num:
            return seq_indices, True

        mb_rnn_masks[indices + steps_mask] = 1
        seq_indices = indices % self.seq_len
        state_indices = (seq_indices == 0).nonzero(as_tuple=False)
        state_pos = indices // self.seq_len
        rnn_indices = state_pos[state_indices] + steps_state[state_indices]

        for s, mb_s in zip(self.rnn_states, mb_rnn_states):
            mb_s[:, rnn_indices, :] = s[:, state_indices, :]

        self.last_rnn_indices = rnn_indices
        self.last_state_indices = state_indices
        return seq_indices, False

    def process_rnn_dones(self, all_done_indices, indices, seq_indices):
        if len(all_done_indices) > 0:
            shifts = self.seq_len - 1 - seq_indices[all_done_indices]
            indices[all_done_indices] += shifts
            for s in self.rnn_states:
                s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0
        indices += 1  


    def cast_obs(self, obs):
        if isinstance(obs, torch.Tensor):
            self.is_tensor_obses = True
        elif isinstance(obs, np.ndarray):
            assert(self.observation_space.dtype != np.int8)
            if self.observation_space.dtype == np.uint8:
                obs = torch.ByteTensor(obs).to(self.ppo_device)
            else:
                obs = torch.FloatTensor(obs).to(self.ppo_device)
        return obs

    def obs_to_tensors(self, obs):
        if isinstance(obs, dict):
            upd_obs = {}
            for key, value in obs.items():
                upd_obs[key] = self.cast_obs(value)
        else:
            upd_obs = {'obs' : self.cast_obs(obs)}
        return upd_obs

    def preprocess_actions(self, actions):
        if not self.is_tensor_obses:
            actions = actions.cpu().numpy()
        return actions

    def env_step(self, actions):
        actions = self.preprocess_actions(actions)
        obs, rewards, dones, infos = self.vec_env.step(actions)

        if self.is_tensor_obses:
            if self.value_size == 1:
                rewards = rewards.unsqueeze(1)
            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos
        else:
            if self.value_size == 1:
                rewards = np.expand_dims(rewards, axis=1)
            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos

    def env_reset(self):
        obs = self.vec_env.reset() 
        obs = self.obs_to_tensors(obs)
        return obs

    def discount_values(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards):
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.steps_num)):
            if t == self.steps_num - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t+1]
                nextvalues = mb_extrinsic_values[t+1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal  - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.tau * nextnonterminal * lastgaelam
        return mb_advs

    def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_masks):
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.steps_num)):
            if t == self.steps_num - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t+1]
                nextvalues = mb_extrinsic_values[t+1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal  - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * mb_masks[t].unsqueeze(1)
        return mb_advs

    def clear_stats(self):
        batch_size = self.num_agents * self.num_actors
        self.game_rewards.clear()
        self.game_lengths.clear()
        self.last_mean_rewards = -100500
        self.algo_observer.after_clear_stats()

    def update_epoch(self):
        pass

    def train(self):       
        pass

    def prepare_dataset(self, batch_dict):
        pass

    def train_epoch(self):
        pass

    def train_actor_critic(self, obs_dict, opt_step=True):
        pass 

    def calc_gradients(self, opt_step):
        pass

    def get_central_value(self, obs_dict):
        return self.central_value_net.get_value(obs_dict)

    def train_central_value(self):
        return self.central_value_net.train_net()

    def get_full_state_weights(self):
        state = self.get_weights()

        state['epoch'] = self.epoch_num
        state['optimizer'] = self.optimizer.state_dict()      

        if self.has_central_value:
            state['assymetric_vf_nets'] = self.central_value_net.state_dict()
        return state

    def set_full_state_weights(self, weights):
        self.set_weights(weights)

        self.epoch_num = weights['epoch']
        if self.has_central_value:
            self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
        self.optimizer.load_state_dict(weights['optimizer'])

    def get_weights(self):
        state = {'model': self.model.state_dict()}

        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.normalize_value:
            state['reward_mean_std'] = self.value_mean_std.state_dict()   
        return state

    def get_stats_weights(self):
        state = {}
        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.normalize_value:
            state['reward_mean_std'] = self.value_mean_std.state_dict()
        if self.has_central_value:
            state['assymetric_vf_mean_std'] = self.central_value_net.get_stats_weights()
        return state

    def set_stats_weights(self, weights):
        if self.normalize_input:
            self.running_mean_std.load_state_dict(weights['running_mean_std'])
        if self.normalize_value:
            self.value_mean_std.load_state_dict(weights['reward_mean_std'])
        if self.has_central_value:
            self.central_value_net.set_stats_weights(weights['assymetric_vf_mean_std'])
  
    def set_weights(self, weights):
        self.model.load_state_dict(weights['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(weights['running_mean_std'])
        if self.normalize_value:
            self.value_mean_std.load_state_dict(weights['reward_mean_std'])

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == torch.uint8:
            obs_batch = obs_batch.float() / 255.0
        #if len(obs_batch.size()) == 3:
        #    obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def play_steps(self):
        mb_rnn_states = []
        epinfos = []

        mb_obs = self.mb_obs
        mb_rewards = self.mb_rewards
        mb_values = self.mb_values
        mb_dones = self.mb_dones
        
        tensors_dict = self.tensors_dict
        update_list = self.update_list
        update_dict = self.update_dict

        if self.has_central_value:
            mb_vobs = self.mb_vobs

        batch_size = self.num_agents * self.num_actors
        mb_rnn_masks = None

        for n in range(self.steps_num):
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            mb_obs[n,:] = self.obs['obs']
            mb_dones[n,:] = self.dones
            for k in update_list:
                tensors_dict[k][n,:] = res_dict[k]

            if self.has_central_value:
                mb_vobs[n,:] = self.obs['states']

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])

            shaped_rewards = self.rewards_shaper(rewards)
            mb_rewards[n,:] = shaped_rewards

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]
  
            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])

            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones
        
        if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                val_dict = self.get_masked_action_values(self.obs, masks)
            else:
                val_dict = self.get_action_values(self.obs)
            last_values = val_dict['value']
        else:
            last_values = self.get_values(self.obs)

        mb_extrinsic_values = mb_values
        last_extrinsic_values = last_values

        fdones = self.dones.float()
        mb_fdones = mb_dones.float()
        mb_advs = self.discount_values(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards)
        mb_returns = mb_advs + mb_extrinsic_values
        batch_dict = {
            'obs' : mb_obs,
            'returns' : mb_returns,
            'dones' : mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]

        if self.has_central_value:
            batch_dict['states'] = mb_vobs

        batch_dict = {k: swap_and_flatten01(v) for k, v in batch_dict.items()}

        return batch_dict

    def play_steps_rnn(self):
        mb_rnn_states = []
        epinfos = []

        mb_obs = self.mb_obs
        mb_values = self.mb_values.fill_(0)
        mb_rewards = self.mb_rewards.fill_(0)
        mb_dones = self.mb_dones.fill_(1)

        tensors_dict = self.tensors_dict
        update_list = self.update_list
        update_dict = self.update_dict

        if self.has_central_value:
            mb_vobs = self.mb_vobs

        batch_size = self.num_agents * self.num_actors
        mb_rnn_masks = None

        mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states = self.init_rnn_step(batch_size, mb_rnn_states)

        for n in range(self.steps_num):
            seq_indices, full_tensor = self.process_rnn_indices(mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states)
            if full_tensor:
                break
            if self.has_central_value:
                self.central_value_net.pre_step_rnn(self.last_rnn_indices, self.last_state_indices)

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)
                
            self.rnn_states = res_dict['rnn_state']

            mb_dones[indices, play_mask] = self.dones.byte()
            mb_obs[indices,play_mask] = self.obs['obs']   

            for k in update_list:
                tensors_dict[k][indices,play_mask] = res_dict[k]

            if self.has_central_value:
                mb_vobs[indices[::self.num_agents] ,play_mask[::self.num_agents]//self.num_agents] = self.obs['states']

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])

            shaped_rewards = self.rewards_shaper(rewards)

            mb_rewards[indices, play_mask] = shaped_rewards

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]

            self.process_rnn_dones(all_done_indices, indices, seq_indices)  
            if self.has_central_value:
                self.central_value_net.post_step_rnn(all_done_indices)
        
            self.algo_observer.process_infos(infos, done_indices)

            fdones = self.dones.float()
            not_dones = 1.0 - self.dones.float()

            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                val_dict = self.get_masked_action_values(self.obs, masks)
            else:
                val_dict = self.get_action_values(self.obs)
            
            last_values = val_dict['value']
        else:
            last_values = self.get_values(self.obs)

        mb_extrinsic_values = mb_values
        last_extrinsic_values = last_values
        fdones = self.dones.float()
        mb_fdones = mb_dones.float()

        non_finished = (indices != self.steps_num).nonzero(as_tuple=False)
        ind_to_fill = indices[non_finished]
        mb_fdones[ind_to_fill,non_finished] = fdones[non_finished]
        mb_extrinsic_values[ind_to_fill,non_finished] = last_extrinsic_values[non_finished]
        fdones[non_finished] = 1.0
        last_extrinsic_values[non_finished] = 0

        mb_advs = self.discount_values_masks(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_rnn_masks.view(-1,self.steps_num).transpose(0,1))

        mb_returns = mb_advs + mb_extrinsic_values

        batch_dict = {
            'obs' : mb_obs,
            'returns' : mb_returns,
            'dones' : mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]

        if self.has_central_value:
            batch_dict['states'] = mb_vobs

        batch_dict = {k: swap_and_flatten01(v) for k, v in batch_dict.items()}

        batch_dict['rnn_states'] = mb_rnn_states
        batch_dict['rnn_masks'] = mb_rnn_masks

        return batch_dict
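
discount_values and discount_values_masks above implement Generalized Advantage Estimation (GAE): each step's TD residual delta_t = r_t + gamma * V_{t+1} * (1 - done_{t+1}) - V_t is accumulated backwards with the factor gamma * tau. A self-contained sketch of the same recursion on plain tensors (a hypothetical standalone function, shown only to make the shapes and the loop explicit):

import torch

def gae_advantages(rewards, values, dones, last_value, last_done,
                   gamma=0.99, tau=0.95):
    # rewards, values: (steps, batch, value_size); dones, last_done: float flags,
    # shapes (steps, batch) and (batch,); last_value: (batch, value_size) estimate
    # for the state reached after the final rollout step.
    steps = rewards.shape[0]
    advs = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(steps)):
        if t == steps - 1:
            nextnonterminal = 1.0 - last_done
            nextvalues = last_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        nextnonterminal = nextnonterminal.unsqueeze(1)
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam
        advs[t] = lastgaelam
    return advs

Adding the advantages back onto the value predictions, as play_steps does with mb_returns = mb_advs + mb_extrinsic_values, gives the regression targets typically used for the critic in PPO.
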