Example #1
    def __init__(self, config):
        BasePlayer.__init__(self, config)

        self.network = config['network']
        if type(self.action_space) is gym.spaces.Discrete:
            self.actions_num = self.action_space.n
            self.is_multi_discrete = False
        if type(self.action_space) is gym.spaces.Tuple:
            self.actions_num = [action.n for action in self.action_space]
            self.is_multi_discrete = True
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']

        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents,
            'value_size': self.value_size
        }

        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()
Example #2
class PpoPlayerContinuous(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.shape[0]
        self.actions_low = torch.from_numpy(
            self.action_space.low.copy()).float().to(self.device)
        self.actions_high = torch.from_numpy(
            self.action_space.high.copy()).float().to(self.device)
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']
        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents
        }
        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()

    def get_action(self, obs, is_determenistic=False):
        if self.has_batch_dimension == False:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'rnn_states': self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        mu = res_dict['mus']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if is_determenistic:
            current_action = mu
        else:
            current_action = action
        current_action = torch.squeeze(current_action.detach())
        return rescale_actions(self.actions_low, self.actions_high,
                               torch.clamp(current_action, -1.0, 1.0))

    def restore(self, fn):
        checkpoint = torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])

    def reset(self):
        self.init_rnn()
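
The continuous player above clamps the sampled (or deterministic) action to [-1, 1] and then maps it into the environment's bounds via rescale_actions. A minimal sketch of what such a helper typically computes, assuming low/high are tensors with the same shape as the action (illustrative, not necessarily the library's exact code):

import torch

def rescale_actions(low, high, action):
    # Linearly map an action in [-1, 1] to the box [low, high], element-wise.
    d = (high - low) / 2.0
    m = (high + low) / 2.0
    return action * d + m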
Example #3
    def __init__(self, state_shape, value_size, ppo_device, num_agents, num_steps, num_actors, num_actions, seq_len, model, config, writter, multi_gpu):
        nn.Module.__init__(self)
        self.ppo_device = ppo_device
        self.num_agents, self.num_steps, self.num_actors, self.seq_len = num_agents, num_steps, num_actors, seq_len
        self.num_actions = num_actions
        self.state_shape = state_shape
        self.value_size = value_size
        self.multi_gpu = multi_gpu
        self.truncate_grads = config.get('truncate_grads', False)
        state_config = {
            'value_size' : value_size,
            'input_shape' : state_shape,
            'actions_num' : num_actions,
            'num_agents' : num_agents,
            'num_seqs' : num_actors
        }

        self.config = config
        self.model = model.build('cvalue', **state_config)
        self.lr = config['lr']
        self.mini_epoch = config['mini_epochs']
        self.mini_batch = config['minibatch_size']
        self.num_minibatches = self.num_steps * self.num_actors // self.mini_batch
        self.clip_value = config['clip_value']
        self.normalize_input = config['normalize_input']
        self.writter = writter
        self.use_joint_obs_actions = config.get('use_joint_obs_actions', False)
        self.weight_decay = config.get('weight_decay', 0.0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), float(self.lr), eps=1e-08, weight_decay=self.weight_decay)
        self.frame = 0
        self.running_mean_std = None
        self.grad_norm = config.get('grad_norm', 1)
        self.e_clip = config.get('e_clip', 0.2)
        
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(state_shape)

        self.is_rnn = self.model.is_rnn()
        self.rnn_states = None
        self.batch_size = self.num_steps * self.num_actors
        if self.is_rnn:
            self.rnn_states = self.model.get_default_rnn_state()
            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
            num_seqs = self.num_steps * self.num_actors // self.seq_len
            assert((self.num_steps * self.num_actors // self.num_minibatches) % self.seq_len == 0)
            self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

        self.dataset = datasets.PPODataset(self.batch_size, self.mini_batch, True, self.is_rnn, self.ppo_device, self.seq_len)
Example #4
    def __init__(self, obs_shape, normalize_value, normalize_input,
                 value_size):
        nn.Module.__init__(self)
        self.obs_shape = obs_shape
        self.normalize_value = normalize_value
        self.normalize_input = normalize_input
        self.value_size = value_size

        if normalize_value:
            self.value_mean_std = RunningMeanStd((self.value_size, ))
        if normalize_input:
            if isinstance(obs_shape, dict):
                self.running_mean_std = RunningMeanStdObs(obs_shape)
            else:
                self.running_mean_std = RunningMeanStd(obs_shape)
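
RunningMeanStd (and its dict counterpart RunningMeanStdObs) maintains running estimates of the input mean and variance and normalizes observations or values with them. A standalone sketch of the parallel mean/variance update such a module is commonly built on, assuming flat tensor inputs (illustrative only, not the library's implementation):

import torch
import torch.nn as nn

class SimpleRunningMeanStd(nn.Module):
    # Illustrative normalizer: updates running statistics in train mode,
    # normalizes with frozen statistics in eval mode.
    def __init__(self, shape, epsilon=1e-5):
        super().__init__()
        self.register_buffer('mean', torch.zeros(shape))
        self.register_buffer('var', torch.ones(shape))
        self.register_buffer('count', torch.tensor(float(epsilon)))

    def forward(self, x):
        if self.training:
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            batch_count = x.size(0)
            delta = batch_mean - self.mean
            total = self.count + batch_count
            new_mean = self.mean + delta * batch_count / total
            m2 = self.var * self.count + batch_var * batch_count \
                + delta.pow(2) * self.count * batch_count / total
            self.mean, self.var, self.count = new_mean, m2 / total, total
        return (x - self.mean) / torch.sqrt(self.var + 1e-5)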
Example #5
    def __init__(self, base_name, config):
        a2c_common.DiscreteA2CBase.__init__(self, base_name, config)
        obs_shape = self.obs_shape

        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_actors * self.num_agents,
            'value_size': self.env_info.get('value_size', 1)
        }

        self.model = self.network.build(config)
        self.model.to(self.ppo_device)

        self.init_rnn_from_model(self.model)

        self.last_lr = float(self.last_lr)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr),
                                    eps=1e-08,
                                    weight_decay=self.weight_decay)

        if self.normalize_input:
            if isinstance(self.observation_space, gym.spaces.Dict):
                self.running_mean_std = RunningMeanStdObs(obs_shape).to(
                    self.ppo_device)
            else:
                self.running_mean_std = RunningMeanStd(obs_shape).to(
                    self.ppo_device)

        if self.has_central_value:
            cv_config = {
                'state_shape': self.state_shape,
                'value_size': self.value_size,
                'ppo_device': self.ppo_device,
                'num_agents': self.num_agents,
                'num_steps': self.steps_num,
                'num_actors': self.num_actors,
                'num_actions': self.actions_num,
                'seq_len': self.seq_len,
                'model': self.central_value_config['network'],
                'config': self.central_value_config,
                'writter': self.writer,
                'multi_gpu': self.multi_gpu
            }
            self.central_value_net = central_value.CentralValueTrain(
                **cv_config).to(self.ppo_device)

        self.use_experimental_cv = self.config.get('use_experimental_cv',
                                                   False)
        self.dataset = datasets.PPODataset(self.batch_size,
                                           self.minibatch_size,
                                           self.is_discrete, self.is_rnn,
                                           self.ppo_device, self.seq_len)
        self.algo_observer.after_init(self)
Example #6
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.shape[0]
        self.actions_low = torch.from_numpy(
            self.action_space.low.copy()).float().to(self.device)
        self.actions_high = torch.from_numpy(
            self.action_space.high.copy()).float().to(self.device)
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']
        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents
        }
        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()
Example #7
    def __init__(self, base_name, config):
        a2c_common.ContinuousA2CBase.__init__(self, base_name, config)
        obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_actors * self.num_agents,
            'value_size': self.env_info.get('value_size', 1)
        }

        self.model = self.network.build(config)
        self.model.to(self.ppo_device)
        self.states = None

        self.init_rnn_from_model(self.model)
        self.last_lr = float(self.last_lr)

        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr),
                                    eps=1e-07,
                                    weight_decay=self.weight_decay)

        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(
                self.ppo_device)

        if self.has_central_value:
            cv_config = {
                'state_shape': torch_ext.shape_whc_to_cwh(self.state_shape),
                'value_size': self.value_size,
                'ppo_device': self.ppo_device,
                'num_agents': self.num_agents,
                'num_steps': self.steps_num,
                'num_actors': self.num_actors,
                'num_actions': self.actions_num,
                'seq_len': self.seq_len,
                'model': self.central_value_config['network'],
                'config': self.central_value_config,
                'writter': self.writer
            }
            self.central_value_net = central_value.CentralValueTrain(
                **cv_config).to(self.ppo_device)
        self.use_experimental_cv = self.config.get('use_experimental_cv', True)
        self.dataset = datasets.PPODataset(self.batch_size,
                                           self.minibatch_size,
                                           self.is_discrete, self.is_rnn,
                                           self.ppo_device, self.seq_len)
        self.algo_observer.after_init(self)
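
This agent converts image observation shapes with torch_ext.shape_whc_to_cwh before building the network, so that channels come first as PyTorch convolutions expect. A likely minimal form of that helper (illustrative sketch):

def shape_whc_to_cwh(shape):
    # Move the channel axis first for 3D image shapes,
    # e.g. (84, 84, 4) -> (4, 84, 84); other shapes are returned unchanged.
    if len(shape) == 3:
        return (shape[2], shape[0], shape[1])
    return shape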
Example #8
class CentralValueTrain(nn.Module):
    def __init__(self, state_shape, value_size, ppo_device, num_agents,
                 num_steps, num_actors, num_actions, seq_len, model, config,
                 writter, multi_gpu):
        nn.Module.__init__(self)
        self.ppo_device = ppo_device
        self.num_agents, self.num_steps, self.num_actors, self.seq_len = num_agents, num_steps, num_actors, seq_len
        self.num_actions = num_actions
        self.state_shape = state_shape
        self.value_size = value_size
        self.multi_gpu = multi_gpu

        state_config = {
            'value_size': value_size,
            'input_shape': state_shape,
            'actions_num': num_actions,
            'num_agents': num_agents,
            'num_seqs': num_actors
        }

        self.config = config
        self.model = model.build('cvalue', **state_config)
        self.lr = config['lr']
        self.mini_epoch = config['mini_epochs']
        self.mini_batch = config['minibatch_size']
        self.num_minibatches = self.num_steps * self.num_actors // self.mini_batch
        self.clip_value = config['clip_value']
        self.normalize_input = config['normalize_input']
        self.writter = writter
        self.use_joint_obs_actions = config.get('use_joint_obs_actions', False)
        self.weight_decay = config.get('weight_decay', 0.0)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          float(self.lr),
                                          eps=1e-08,
                                          weight_decay=self.weight_decay)
        self.frame = 0
        self.running_mean_std = None
        self.grad_norm = config.get('grad_norm', 1)
        self.truncate_grads = config.get('truncate_grads', False)
        self.e_clip = config.get('e_clip', 0.2)
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(state_shape)

        self.is_rnn = self.model.is_rnn()
        self.rnn_states = None
        self.batch_size = self.num_steps * self.num_actors
        if self.is_rnn:
            self.rnn_states = self.model.get_default_rnn_state()
            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
            num_seqs = self.num_steps * self.num_actors // self.seq_len
            assert (
                (self.num_steps * self.num_actors // self.num_minibatches) %
                self.seq_len == 0)
            self.mb_rnn_states = [
                torch.zeros((s.size()[0], num_seqs, s.size()[2]),
                            dtype=torch.float32,
                            device=self.ppo_device) for s in self.rnn_states
            ]

        self.dataset = datasets.PPODataset(self.batch_size, self.mini_batch,
                                           True, self.is_rnn, self.ppo_device,
                                           self.seq_len)

    def update_lr(self, lr):
        '''
        if self.multi_gpu:
            lr_tensor = torch.tensor([lr])
            self.hvd.broadcast_value(lr_tensor, 'cv_learning_rate')
            lr = lr_tensor.item()
        '''
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_stats_weights(self):
        if self.normalize_input:
            return self.running_mean_std.state_dict()
        else:
            return {}

    def set_stats_weights(self, weights):
        self.running_mean_std.load_state_dict(weights)

    def update_dataset(self, batch_dict):
        value_preds = batch_dict['old_values']
        returns = batch_dict['returns']
        actions = batch_dict['actions']
        rnn_masks = batch_dict['rnn_masks']
        if self.num_agents > 1:
            res = self.update_multiagent_tensors(value_preds, returns, actions,
                                                 rnn_masks)
            batch_dict['old_values'] = res[0]
            batch_dict['returns'] = res[1]
            batch_dict['actions'] = res[2]

        if self.is_rnn:
            batch_dict['rnn_states'] = self.mb_rnn_states
            if self.num_agents > 1:
                rnn_masks = res[3]
            batch_dict['rnn_masks'] = rnn_masks
        self.dataset.update_values_dict(batch_dict)

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == torch.uint8:
            obs_batch = obs_batch.float() / 255.0
        #if len(obs_batch.size()) == 3:
        #    obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def pre_step_rnn(self, rnn_indices, state_indices):
        if self.num_agents > 1:
            rnn_indices = rnn_indices[::self.num_agents]
            shifts = rnn_indices % (self.num_steps // self.seq_len)
            rnn_indices = (rnn_indices - shifts) // self.num_agents + shifts
            state_indices = state_indices[::self.num_agents] // self.num_agents

        for s, mb_s in zip(self.rnn_states, self.mb_rnn_states):
            mb_s[:, rnn_indices, :] = s[:, state_indices, :]

    def post_step_rnn(self, all_done_indices):
        all_done_indices = all_done_indices[::self.num_agents] // self.num_agents
        for s in self.rnn_states:
            s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0

    def forward(self, input_dict):
        value, rnn_states = self.model(input_dict)
        return value, rnn_states

    def get_value(self, input_dict):
        self.eval()
        obs_batch = input_dict['states']

        actions = input_dict.get('actions', None)

        obs_batch = self._preproc_obs(obs_batch)
        value, self.rnn_states = self.forward({
            'obs': obs_batch,
            'actions': actions,
            'rnn_states': self.rnn_states
        })
        if self.num_agents > 1:
            value = value.repeat(1, self.num_agents)
            value = value.view(value.size()[0] * self.num_agents, -1)

        return value

    def train_critic(self, input_dict):
        self.train()
        loss = self.calc_gradients(input_dict)

        return loss.item()

    def update_multiagent_tensors(self, value_preds, returns, actions,
                                  rnn_masks):
        batch_size = self.batch_size
        ma_batch_size = self.num_actors * self.num_agents * self.num_steps
        value_preds = value_preds.view(self.num_actors, self.num_agents,
                                       self.num_steps,
                                       self.value_size).transpose(0, 1)
        returns = returns.view(self.num_actors, self.num_agents,
                               self.num_steps,
                               self.value_size).transpose(0, 1)
        value_preds = value_preds.contiguous().view(
            ma_batch_size, self.value_size)[:batch_size]
        returns = returns.contiguous().view(ma_batch_size,
                                            self.value_size)[:batch_size]

        if self.use_joint_obs_actions:
            assert len(actions.size()) == 2, \
                'use_joint_obs_actions not yet supported in continuous environment for central value'
            actions = actions.view(self.num_actors, self.num_agents,
                                   self.num_steps).transpose(0, 1)
            actions = actions.contiguous().view(batch_size, self.num_agents)

        if self.is_rnn:
            rnn_masks = rnn_masks.view(self.num_actors, self.num_agents,
                                       self.num_steps).transpose(0, 1)
            rnn_masks = rnn_masks.flatten(0)[:batch_size]
        return value_preds, returns, actions, rnn_masks

    def train_net(self):
        self.train()
        loss = 0
        for _ in range(self.mini_epoch):
            for idx in range(len(self.dataset)):
                loss += self.train_critic(self.dataset[idx])
        avg_loss = loss / (self.mini_epoch * self.num_minibatches)

        if self.writter is not None:
            self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame)

        self.frame += self.batch_size
        return avg_loss

    def calc_gradients(self, batch):
        obs_batch = self._preproc_obs(batch['obs'])
        value_preds_batch = batch['old_values']
        returns_batch = batch['returns']
        actions_batch = batch['actions']
        rnn_masks_batch = batch.get('rnn_masks')

        batch_dict = {
            'obs': obs_batch,
            'actions': actions_batch,
            'seq_length': self.seq_len
        }
        if self.is_rnn:
            batch_dict['rnn_states'] = batch['rnn_states']

        values, _ = self.forward(batch_dict)
        loss = common_losses.critic_loss(value_preds_batch, values,
                                         self.e_clip, returns_batch,
                                         self.clip_value)
        losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
        loss = losses[0]
        if self.multi_gpu:
            self.optimizer.zero_grad()
        else:
            for param in self.model.parameters():
                param.grad = None
        loss.backward()

        #TODO: Refactor this ugliest code of the year
        if self.truncate_grads:
            if self.multi_gpu:
                self.optimizer.synchronize()
                #self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.grad_norm)
                with self.optimizer.skip_synchronize():
                    self.optimizer.step()
            else:
                #self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.grad_norm)
                self.optimizer.step()
        else:
            self.optimizer.step()

        return loss
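
calc_gradients above delegates the value objective to common_losses.critic_loss, parameterized by e_clip and clip_value. A hedged sketch of the standard PPO-style clipped value loss that such a helper typically implements (the library's exact formulation may differ):

import torch

def clipped_critic_loss(value_preds, values, e_clip, returns, clip_value):
    # Clip the new value prediction to stay within e_clip of the old one and
    # take the element-wise maximum of the clipped and unclipped squared errors.
    if clip_value:
        value_pred_clipped = value_preds + (values - value_preds).clamp(-e_clip, e_clip)
        value_losses = (values - returns) ** 2
        value_losses_clipped = (value_pred_clipped - returns) ** 2
        return torch.max(value_losses, value_losses_clipped)
    return (values - returns) ** 2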
Example #9
class PpoPlayerDiscrete(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)

        self.network = config['network']
        if type(self.action_space) is gym.spaces.Discrete:
            self.actions_num = self.action_space.n
            self.is_multi_discrete = False
        if type(self.action_space) is gym.spaces.Tuple:
            self.actions_num = [action.n for action in self.action_space]
            self.is_multi_discrete = True
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']

        obs_shape = self.obs_shape
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_agents,
            'value_size': self.value_size
        }

        self.model = self.network.build(config)
        self.model.to(self.device)
        self.model.eval()
        self.is_rnn = self.model.is_rnn()
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)
            self.running_mean_std.eval()

    def get_masked_action(self, obs, action_masks, is_determenistic=True):
        if self.has_batch_dimension == False:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        action_masks = torch.Tensor(action_masks).to(self.device)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'action_masks': action_masks,
            'rnn_states': self.states
        }
        self.model.eval()

        with torch.no_grad():
            res_dict = self.model(input_dict)
        logits = res_dict['logits']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if self.is_multi_discrete:
            if is_determenistic:
                action = [
                    torch.argmax(logit.detach(), axis=-1).squeeze()
                    for logit in logits
                ]
                return torch.stack(action, dim=-1)
            else:
                return action.squeeze().detach()
        else:
            if is_determenistic:
                return torch.argmax(logits.detach(), axis=-1).squeeze()
            else:
                return action.squeeze().detach()

    def get_action(self, obs, is_determenistic=False):
        if self.has_batch_dimension == False:
            obs = unsqueeze_obs(obs)
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs': obs,
            'rnn_states': self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        logits = res_dict['logits']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if self.is_multi_discrete:
            if is_determenistic:
                action = [
                    torch.argmax(logit.detach(), axis=1).squeeze()
                    for logit in logits
                ]
                return torch.stack(action, dim=-1)
            else:
                return action.squeeze().detach()
        else:
            if is_determenistic:
                return torch.argmax(logits.detach(), axis=-1).squeeze()
            else:
                return action.squeeze().detach()

    def restore(self, fn):
        checkpoint = torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])

    def reset(self):
        self.init_rnn()
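
Both player classes call unsqueeze_obs when has_batch_dimension is False, i.e. when a single observation rather than a batch is passed to get_action. A minimal sketch of such a helper, assuming observations are tensors or dicts of tensors (illustrative):

def unsqueeze_obs(obs):
    # Add a leading batch dimension to a single observation.
    if isinstance(obs, dict):
        return {k: v.unsqueeze(0) for k, v in obs.items()}
    return obs.unsqueeze(0)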
Example #10
    def __init__(self, base_name, config):
        self.config = config
        self.env_config = config.get('env_config', {})
        self.num_actors = config['num_actors']
        self.env_name = config['env_name']

        self.env_info = config.get('env_info')
        if self.env_info is None:
            self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
            self.env_info = self.vec_env.get_env_info()

        self.ppo_device = config.get('device', 'cuda:0')
        print('Env info:')
        print(self.env_info)
        self.value_size = self.env_info.get('value_size',1)
        self.observation_space = self.env_info['observation_space']
        self.weight_decay = config.get('weight_decay', 0.0)
        self.use_action_masks = config.get('use_action_masks', False)
        self.is_train = config.get('is_train', True)

        self.central_value_config = self.config.get('central_value_config', None)
        self.has_central_value = self.central_value_config is not None

        if self.has_central_value:
            self.state_space = self.env_info.get('state_space', None)
            self.state_shape = None
            if self.state_space.shape != None:
                self.state_shape = self.state_space.shape

        self.self_play_config = self.config.get('self_play_config', None)
        self.has_self_play_config = self.self_play_config is not None

        self.self_play = config.get('self_play', False)
        self.save_freq = config.get('save_frequency', 0)
        self.save_best_after = config.get('save_best_after', 100)
        self.print_stats = config.get('print_stats', True)
        self.rnn_states = None
        self.name = base_name

        self.ppo = config['ppo']
        self.max_epochs = self.config.get('max_epochs', 1e6)

        self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
        self.linear_lr = config['lr_schedule'] == 'linear'
        self.schedule_type = config.get('schedule_type', 'legacy')
        if self.is_adaptive_lr:
            self.lr_threshold = config['lr_threshold']
            self.scheduler = schedulers.AdaptiveScheduler(self.lr_threshold)
        elif self.linear_lr:
            self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']), 
                max_steps=self.max_epochs, 
                apply_to_entropy=config.get('schedule_entropy', False),
                start_entropy_coef=config.get('entropy_coef'))
        else:
            self.scheduler = schedulers.IdentityScheduler()

        self.e_clip = config['e_clip']
        self.clip_value = config['clip_value']
        self.network = config['network']
        self.rewards_shaper = config['reward_shaper']
        self.num_agents = self.env_info.get('agents', 1)
        self.steps_num = config['steps_num']
        self.seq_len = self.config.get('seq_length', 4)
        self.normalize_advantage = config['normalize_advantage']
        self.normalize_input = self.config['normalize_input']
        self.normalize_value = self.config.get('normalize_value', False)

        self.obs_shape = self.observation_space.shape
 
        self.critic_coef = config['critic_coef']
        self.grad_norm = config['grad_norm']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.games_to_track = self.config.get('games_to_track', 100)
        self.game_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
        self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
        self.obs = None
        self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
        self.batch_size = self.steps_num * self.num_actors * self.num_agents
        self.batch_size_envs = self.steps_num * self.num_actors
        self.minibatch_size = self.config['minibatch_size']
        self.mini_epochs_num = self.config['mini_epochs']
        self.num_minibatches = self.batch_size // self.minibatch_size
        assert(self.batch_size % self.minibatch_size == 0)

        self.last_lr = self.config['learning_rate']
        self.frame = 0
        self.update_time = 0
        self.last_mean_rewards = -100500
        self.play_time = 0
        self.epoch_num = 0
        
        self.entropy_coef = self.config['entropy_coef']
        self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S"))

        if self.normalize_value:
            self.value_mean_std = RunningMeanStd((1,)).to(self.ppo_device)

        self.is_tensor_obses = False

        self.last_rnn_indices = None
        self.last_state_indices = None

        #self_play
        if self.has_self_play_config:
            print('Initializing SelfPlay Manager')
            self.self_play_manager = SelfPlayManager(self.self_play_config, self.writer)
        
        # features
        self.algo_observer = config['features']['observer']
Example #11
class A2CBase:
    def __init__(self, base_name, config):
        self.config = config
        self.env_config = config.get('env_config', {})
        self.num_actors = config['num_actors']
        self.env_name = config['env_name']

        self.env_info = config.get('env_info')
        if self.env_info is None:
            self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
            self.env_info = self.vec_env.get_env_info()

        self.ppo_device = config.get('device', 'cuda:0')
        print('Env info:')
        print(self.env_info)
        self.value_size = self.env_info.get('value_size',1)
        self.observation_space = self.env_info['observation_space']
        self.weight_decay = config.get('weight_decay', 0.0)
        self.use_action_masks = config.get('use_action_masks', False)
        self.is_train = config.get('is_train', True)

        self.central_value_config = self.config.get('central_value_config', None)
        self.has_central_value = self.central_value_config is not None

        if self.has_central_value:
            self.state_space = self.env_info.get('state_space', None)
            self.state_shape = None
            if self.state_space.shape != None:
                self.state_shape = self.state_space.shape

        self.self_play_config = self.config.get('self_play_config', None)
        self.has_self_play_config = self.self_play_config is not None

        self.self_play = config.get('self_play', False)
        self.save_freq = config.get('save_frequency', 0)
        self.save_best_after = config.get('save_best_after', 100)
        self.print_stats = config.get('print_stats', True)
        self.rnn_states = None
        self.name = base_name

        self.ppo = config['ppo']
        self.max_epochs = self.config.get('max_epochs', 1e6)

        self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
        self.linear_lr = config['lr_schedule'] == 'linear'
        self.schedule_type = config.get('schedule_type', 'legacy')
        if self.is_adaptive_lr:
            self.lr_threshold = config['lr_threshold']
            self.scheduler = schedulers.AdaptiveScheduler(self.lr_threshold)
        elif self.linear_lr:
            self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']), 
                max_steps=self.max_epochs, 
                apply_to_entropy=config.get('schedule_entropy', False),
                start_entropy_coef=config.get('entropy_coef'))
        else:
            self.scheduler = schedulers.IdentityScheduler()

        self.e_clip = config['e_clip']
        self.clip_value = config['clip_value']
        self.network = config['network']
        self.rewards_shaper = config['reward_shaper']
        self.num_agents = self.env_info.get('agents', 1)
        self.steps_num = config['steps_num']
        self.seq_len = self.config.get('seq_length', 4)
        self.normalize_advantage = config['normalize_advantage']
        self.normalize_input = self.config['normalize_input']
        self.normalize_value = self.config.get('normalize_value', False)

        self.obs_shape = self.observation_space.shape
 
        self.critic_coef = config['critic_coef']
        self.grad_norm = config['grad_norm']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.games_to_track = self.config.get('games_to_track', 100)
        self.game_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
        self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
        self.obs = None
        self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
        self.batch_size = self.steps_num * self.num_actors * self.num_agents
        self.batch_size_envs = self.steps_num * self.num_actors
        self.minibatch_size = self.config['minibatch_size']
        self.mini_epochs_num = self.config['mini_epochs']
        self.num_minibatches = self.batch_size // self.minibatch_size
        assert(self.batch_size % self.minibatch_size == 0)

        self.last_lr = self.config['learning_rate']
        self.frame = 0
        self.update_time = 0
        self.last_mean_rewards = -100500
        self.play_time = 0
        self.epoch_num = 0
        
        self.entropy_coef = self.config['entropy_coef']
        self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S"))

        if self.normalize_value:
            self.value_mean_std = RunningMeanStd((1,)).to(self.ppo_device)

        self.is_tensor_obses = False

        self.last_rnn_indices = None
        self.last_state_indices = None

        #self_play
        if self.has_self_play_config:
            print('Initializing SelfPlay Manager')
            self.self_play_manager = SelfPlayManager(self.self_play_config, self.writer)
        
        # features
        self.algo_observer = config['features']['observer']

    def set_eval(self):
        self.model.eval()
        if self.normalize_input:
            self.running_mean_std.eval()
        if self.normalize_value:
            self.value_mean_std.eval()

    def set_train(self):
        self.model.train()
        if self.normalize_input:
            self.running_mean_std.train()
        if self.normalize_value:
            self.value_mean_std.train()

    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_action_values(self, obs):
        processed_obs = self._preproc_obs(obs['obs'])
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None, 
            'obs' : processed_obs,
            'rnn_states' : self.rnn_states
        }

        with torch.no_grad():
            res_dict = self.model(input_dict)
            if self.has_central_value:
                states = obs['states']
                input_dict = {
                    'is_train': False,
                    'states' : states,
                    #'actions' : res_dict['action'],
                    #'rnn_states' : self.rnn_states
                }
                value = self.get_central_value(input_dict)
                res_dict['value'] = value
        if self.normalize_value:
            res_dict['value'] = self.value_mean_std(res_dict['value'], True)
        return res_dict

    def get_values(self, obs):
        with torch.no_grad():
            if self.has_central_value:
                states = obs['states']
                self.central_value_net.eval()
                input_dict = {
                    'is_train': False,
                    'states' : states,
                    'actions' : None,
                    'is_done': self.dones,
                }
                value = self.get_central_value(input_dict)
            else:
                self.model.eval()
                processed_obs = self._preproc_obs(obs['obs'])
                input_dict = {
                    'is_train': False,
                    'prev_actions': None, 
                    'obs' : processed_obs,
                    'rnn_states' : self.rnn_states
                }
                
                result = self.model(input_dict)
                value = result['value']

            if self.normalize_value:
                value = self.value_mean_std(value, True)
            return value

    def reset_envs(self):
        self.obs = self.env_reset()

    def init_tensors(self):
        if self.observation_space.dtype == np.uint8:
            torch_dtype = torch.uint8
        else:
            torch_dtype = torch.float32
        batch_size = self.num_agents * self.num_actors
 
        val_shape = (self.steps_num, batch_size, self.value_size)
        current_rewards_shape = (batch_size, self.value_size)
        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
        self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.ppo_device)
        self.mb_obs = torch.zeros((self.steps_num, batch_size) + self.obs_shape, dtype=torch_dtype, device=self.ppo_device)

        if self.has_central_value:
            self.mb_vobs = torch.zeros((self.steps_num, self.num_actors) + self.state_shape, dtype=torch_dtype, device=self.ppo_device)

        self.mb_rewards = torch.zeros(val_shape, dtype = torch.float32, device=self.ppo_device)
        self.mb_values = torch.zeros(val_shape, dtype = torch.float32, device=self.ppo_device)
        self.mb_dones = torch.zeros((self.steps_num, batch_size), dtype = torch.uint8, device=self.ppo_device)
        self.mb_neglogpacs = torch.zeros((self.steps_num, batch_size), dtype = torch.float32, device=self.ppo_device)

        if self.is_rnn:
            self.rnn_states = self.model.get_default_rnn_state()
            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]

            batch_size = self.num_agents * self.num_actors
            num_seqs = self.steps_num * batch_size // self.seq_len
            assert((self.steps_num * batch_size // self.num_minibatches) % self.seq_len == 0)
            self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

    def init_rnn_from_model(self, model):
        self.is_rnn = self.model.is_rnn()

    def init_rnn_step(self, batch_size, mb_rnn_states):
        mb_rnn_states = self.mb_rnn_states
        mb_rnn_masks = torch.zeros(self.steps_num*batch_size, dtype = torch.float32, device=self.ppo_device)
        steps_mask = torch.arange(0, batch_size * self.steps_num, self.steps_num, dtype=torch.long, device=self.ppo_device)
        play_mask = torch.arange(0, batch_size, 1, dtype=torch.long, device=self.ppo_device)
        steps_state = torch.arange(0, batch_size * self.steps_num//self.seq_len, self.steps_num//self.seq_len, dtype=torch.long, device=self.ppo_device)
        indices = torch.zeros((batch_size), dtype = torch.long, device=self.ppo_device)
        return mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states

    def process_rnn_indices(self, mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states):
        seq_indices = None
        if indices.max().item() >= self.steps_num:
            return seq_indices, True

        mb_rnn_masks[indices + steps_mask] = 1
        seq_indices = indices % self.seq_len
        state_indices = (seq_indices == 0).nonzero(as_tuple=False)
        state_pos = indices // self.seq_len
        rnn_indices = state_pos[state_indices] + steps_state[state_indices]

        for s, mb_s in zip(self.rnn_states, mb_rnn_states):
            mb_s[:, rnn_indices, :] = s[:, state_indices, :]

        self.last_rnn_indices = rnn_indices
        self.last_state_indices = state_indices
        return seq_indices, False

    def process_rnn_dones(self, all_done_indices, indices, seq_indices):
        if len(all_done_indices) > 0:
            shifts = self.seq_len - 1 - seq_indices[all_done_indices]
            indices[all_done_indices] += shifts
            for s in self.rnn_states:
                s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0
        indices += 1  


    def cast_obs(self, obs):
        if isinstance(obs, torch.Tensor):
            self.is_tensor_obses = True
        elif isinstance(obs, np.ndarray):
            assert(self.observation_space.dtype != np.int8)
            if self.observation_space.dtype == np.uint8:
                obs = torch.ByteTensor(obs).to(self.ppo_device)
            else:
                obs = torch.FloatTensor(obs).to(self.ppo_device)
        return obs

    def obs_to_tensors(self, obs):
        if isinstance(obs, dict):
            upd_obs = {}
            for key, value in obs.items():
                upd_obs[key] = self.cast_obs(value)
        else:
            upd_obs = {'obs' : self.cast_obs(obs)}
        return upd_obs

    def preprocess_actions(self, actions):
        if not self.is_tensor_obses:
            actions = actions.cpu().numpy()
        return actions

    def env_step(self, actions):
        actions = self.preprocess_actions(actions)
        obs, rewards, dones, infos = self.vec_env.step(actions)

        if self.is_tensor_obses:
            if self.value_size == 1:
                rewards = rewards.unsqueeze(1)
            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos
        else:
            if self.value_size == 1:
                rewards = np.expand_dims(rewards, axis=1)
            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos

    def env_reset(self):
        obs = self.vec_env.reset() 
        obs = self.obs_to_tensors(obs)
        return obs

    def discount_values(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards):
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.steps_num)):
            if t == self.steps_num - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t+1]
                nextvalues = mb_extrinsic_values[t+1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal  - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.tau * nextnonterminal * lastgaelam
        return mb_advs

    def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_masks):
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.steps_num)):
            if t == self.steps_num - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t+1]
                nextvalues = mb_extrinsic_values[t+1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal  - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * mb_masks[t].unsqueeze(1)
        return mb_advs

    def clear_stats(self):
        batch_size = self.num_agents * self.num_actors
        self.game_rewards.clear()
        self.game_lengths.clear()
        self.last_mean_rewards = -100500
        self.algo_observer.after_clear_stats()

    def update_epoch(self):
        pass

    def train(self):       
        pass

    def prepare_dataset(self, batch_dict):
        pass

    def train_epoch(self):
        pass

    def train_actor_critic(self, obs_dict, opt_step=True):
        pass 

    def calc_gradients(self, opt_step):
        pass

    def get_central_value(self, obs_dict):
        return self.central_value_net.get_value(obs_dict)

    def train_central_value(self):
        return self.central_value_net.train_net()

    def get_full_state_weights(self):
        state = self.get_weights()

        state['epoch'] = self.epoch_num
        state['optimizer'] = self.optimizer.state_dict()      

        if self.has_central_value:
            state['assymetric_vf_nets'] = self.central_value_net.state_dict()
        return state

    def set_full_state_weights(self, weights):
        self.set_weights(weights)

        self.epoch_num = weights['epoch']
        if self.has_central_value:
            self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
        self.optimizer.load_state_dict(weights['optimizer'])

    def get_weights(self):
        state = {'model': self.model.state_dict()}

        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.normalize_value:
            state['reward_mean_std'] = self.value_mean_std.state_dict()   
        return state

    def get_stats_weights(self):
        state = {}
        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.normalize_value:
            state['reward_mean_std'] = self.value_mean_std.state_dict()
        if self.has_central_value:
            state['assymetric_vf_mean_std'] = self.central_value_net.get_stats_weights()
        return state

    def set_stats_weights(self, weights):
        if self.normalize_input:
            self.running_mean_std.load_state_dict(weights['running_mean_std'])
        if self.normalize_value:
            self.value_mean_std.load_state_dict(weights['reward_mean_std'])
        if self.has_central_value:
            self.central_value_net.set_stats_weights(weights['assymetric_vf_mean_std'])
  
    def set_weights(self, weights):
        self.model.load_state_dict(weights['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(weights['running_mean_std'])
        if self.normalize_value:
            self.value_mean_std.load_state_dict(weights['reward_mean_std'])

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == torch.uint8:
            obs_batch = obs_batch.float() / 255.0
        #if len(obs_batch.size()) == 3:
        #    obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def play_steps(self):
        mb_rnn_states = []
        epinfos = []

        mb_obs = self.mb_obs
        mb_rewards = self.mb_rewards
        mb_values = self.mb_values
        mb_dones = self.mb_dones
        
        tensors_dict = self.tensors_dict
        update_list = self.update_list
        update_dict = self.update_dict

        if self.has_central_value:
            mb_vobs = self.mb_vobs

        batch_size = self.num_agents * self.num_actors
        mb_rnn_masks = None

        for n in range(self.steps_num):
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            mb_obs[n,:] = self.obs['obs']
            mb_dones[n,:] = self.dones
            for k in update_list:
                tensors_dict[k][n,:] = res_dict[k]

            if self.has_central_value:
                mb_vobs[n,:] = self.obs['states']

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])

            shaped_rewards = self.rewards_shaper(rewards)
            mb_rewards[n,:] = shaped_rewards

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]
  
            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])

            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones
        
        if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                val_dict = self.get_masked_action_values(self.obs, masks)
            else:
                val_dict = self.get_action_values(self.obs)
            last_values = val_dict['value']
        else:
            last_values = self.get_values(self.obs)

        mb_extrinsic_values = mb_values
        last_extrinsic_values = last_values

        fdones = self.dones.float()
        mb_fdones = mb_dones.float()
        mb_advs = self.discount_values(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards)
        mb_returns = mb_advs + mb_extrinsic_values
        batch_dict = {
            'obs' : mb_obs,
            'returns' : mb_returns,
            'dones' : mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]

        if self.has_central_value:
            batch_dict['states'] = mb_vobs

        batch_dict = {k: swap_and_flatten01(v) for k, v in batch_dict.items()}

        return batch_dict

    def play_steps_rnn(self):
        mb_rnn_states = []
        epinfos = []

        mb_obs = self.mb_obs
        mb_values = self.mb_values.fill_(0)
        mb_rewards = self.mb_rewards.fill_(0)
        mb_dones = self.mb_dones.fill_(1)

        tensors_dict = self.tensors_dict
        update_list = self.update_list
        update_dict = self.update_dict

        if self.has_central_value:
            mb_vobs = self.mb_vobs

        batch_size = self.num_agents * self.num_actors
        mb_rnn_masks = None

        mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states = self.init_rnn_step(batch_size, mb_rnn_states)

        for n in range(self.steps_num):
            seq_indices, full_tensor = self.process_rnn_indices(mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states)
            if full_tensor:
                break
            if self.has_central_value:
                self.central_value_net.pre_step_rnn(self.last_rnn_indices, self.last_state_indices)

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)
                
            self.rnn_states = res_dict['rnn_state']

            mb_dones[indices, play_mask] = self.dones.byte()
            mb_obs[indices,play_mask] = self.obs['obs']   

            for k in update_list:
                tensors_dict[k][indices,play_mask] = res_dict[k]

            if self.has_central_value:
                mb_vobs[indices[::self.num_agents] ,play_mask[::self.num_agents]//self.num_agents] = self.obs['states']

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])

            shaped_rewards = self.rewards_shaper(rewards)

            mb_rewards[indices, play_mask] = shaped_rewards

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]

            self.process_rnn_dones(all_done_indices, indices, seq_indices)  
            if self.has_central_value:
                self.central_value_net.post_step_rnn(all_done_indices)
        
            self.algo_observer.process_infos(infos, done_indices)

            fdones = self.dones.float()
            not_dones = 1.0 - self.dones.float()

            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                val_dict = self.get_masked_action_values(self.obs, masks)
            else:
                val_dict = self.get_action_values(self.obs)
            
            last_values = val_dict['value']
        else:
            last_values = self.get_values(self.obs)

        mb_extrinsic_values = mb_values
        last_extrinsic_values = last_values
        fdones = self.dones.float()
        mb_fdones = mb_dones.float()

        non_finished = (indices != self.steps_num).nonzero(as_tuple=False)
        ind_to_fill = indices[non_finished]
        mb_fdones[ind_to_fill,non_finished] = fdones[non_finished]
        mb_extrinsic_values[ind_to_fill,non_finished] = last_extrinsic_values[non_finished]
        fdones[non_finished] = 1.0
        last_extrinsic_values[non_finished] = 0

        mb_advs = self.discount_values_masks(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_rnn_masks.view(-1,self.steps_num).transpose(0,1))

        mb_returns = mb_advs + mb_extrinsic_values

        batch_dict = {
            'obs' : mb_obs,
            'returns' : mb_returns,
            'dones' : mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]

        if self.has_central_value:
            batch_dict['states'] = mb_vobs

        batch_dict = {k: swap_and_flatten01(v) for k, v in batch_dict.items()}

        batch_dict['rnn_states'] = mb_rnn_states
        batch_dict['rnn_masks'] = mb_rnn_masks

        return batch_dict
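
play_steps and play_steps_rnn collect rollout tensors shaped (steps_num, batch_size, ...) and pass every entry of batch_dict through swap_and_flatten01 before it reaches the PPO dataset. A likely minimal implementation of that helper (illustrative sketch):

def swap_and_flatten01(arr):
    # Swap the time and environment axes, then merge them into a single
    # leading axis: (steps, envs, ...) -> (steps * envs, ...).
    if arr is None:
        return arr
    s = arr.size()
    return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:])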
Example #12
    def __init__(self, base_name, params):
        a2c_common.DiscreteA2CBase.__init__(self, base_name, params)
        obs_shape = self.obs_shape

        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'num_seqs': self.num_actors * self.num_agents,
            'value_size': self.env_info.get('value_size', 1),
            'normalize_value': self.normalize_value,
            'normalize_input': self.normalize_input,
        }

        self.model = self.network.build(config)
        self.model.to(self.ppo_device)
        if self.ewma_ppo:
            self.ewma_model = EwmaModel(self.model, ewma_decay=0.889)
        self.init_rnn_from_model(self.model)

        self.last_lr = float(self.last_lr)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr),
                                    eps=1e-08,
                                    weight_decay=self.weight_decay)

        if self.normalize_input:
            if isinstance(self.observation_space, gym.spaces.Dict):
                self.running_mean_std = RunningMeanStdObs(obs_shape).to(
                    self.ppo_device)
            else:
                self.running_mean_std = RunningMeanStd(obs_shape).to(
                    self.ppo_device)

        if self.has_central_value:
            cv_config = {
                'state_shape': self.state_shape,
                'value_size': self.value_size,
                'ppo_device': self.ppo_device,
                'num_agents': self.num_agents,
                'horizon_length': self.horizon_length,
                'num_actors': self.num_actors,
                'num_actions': self.actions_num,
                'seq_len': self.seq_len,
                'normalize_value': self.normalize_value,
                'network': self.central_value_config['network'],
                'config': self.central_value_config,
                'writter': self.writer,
                'max_epochs': self.max_epochs,
                'multi_gpu': self.multi_gpu
            }
            self.central_value_net = central_value.CentralValueTrain(
                **cv_config).to(self.ppo_device)

        self.use_experimental_cv = self.config.get('use_experimental_cv',
                                                   False)
        self.dataset = datasets.PPODataset(self.batch_size,
                                           self.minibatch_size,
                                           self.is_discrete, self.is_rnn,
                                           self.ppo_device, self.seq_len)

        if 'phasic_policy_gradients' in self.config:
            self.has_phasic_policy_gradients = True
            self.ppg_aux_loss = ppg_aux.PPGAux(
                self, self.config['phasic_policy_gradients'])
        else:
            self.has_phasic_policy_gradients = False
        if self.normalize_value:
            self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std
        self.has_value_loss = (self.has_central_value and self.use_experimental_cv) \
                            or (not self.has_phasic_policy_gradients and not self.has_central_value)
        self.algo_observer.after_init(self)
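
When ewma_ppo is enabled, this agent wraps the policy in an EwmaModel with ewma_decay=0.889, i.e. it keeps an exponential-moving-average copy of the network weights. A hedged standalone sketch of that kind of weight tracking (SimpleEwmaModel and its update method are illustrative names, not the library's API):

import copy
import torch

class SimpleEwmaModel:
    # Keeps an exponential moving average of a model's parameters,
    # updated towards the live model after each optimizer step.
    def __init__(self, model, ewma_decay=0.889):
        self.ewma_decay = ewma_decay
        self.ewma_model = copy.deepcopy(model).eval()

    @torch.no_grad()
    def update(self, model):
        for p_ewma, p in zip(self.ewma_model.parameters(), model.parameters()):
            p_ewma.mul_(self.ewma_decay).add_(p, alpha=1.0 - self.ewma_decay)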