Example #1: get_action for a feed-forward agent with dict observations and a hybrid continuous/discrete action space.
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        # Move every entry of the dict observation to the configured device.
        t_ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        act_dist_cont, act_dist_disc, val = self.get_act_val(t_ob)
        # Draw an action from each head (sampled or deterministic, per `sample`).
        action_cont = action_from_dist(act_dist_cont, sample=sample)
        action_discrete = action_from_dist(act_dist_disc, sample=sample)
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        # Joint log-prob/entropy: the continuous term plus the sum over the
        # discrete action dimensions.
        log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
        entropy = entropy_cont + torch.sum(entropy_disc, dim=1)

        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        # Concatenate both parts into a single array for the environment.
        action = np.concatenate(
            (torch_to_np(action_cont), torch_to_np(action_discrete)), axis=1)

        return action, action_info
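
The joint log-probability and entropy above add the continuous term to the sum of the per-dimension discrete terms, which is the usual factorization when the action heads are treated as independent. Below is a minimal, self-contained sketch of that factorization using plain torch.distributions; the distribution choices and shapes are illustrative assumptions, not taken from the original code.

import torch
from torch.distributions import Categorical, Independent, Normal

batch_size, dim_cont, n_disc, n_choices = 4, 3, 2, 5

# Continuous head: diagonal Gaussian over dim_cont dims -> one log-prob per sample.
act_dist_cont = Independent(Normal(torch.zeros(batch_size, dim_cont),
                                   torch.ones(batch_size, dim_cont)), 1)
# Discrete head: n_disc independent categorical choices -> one log-prob per choice.
act_dist_disc = Categorical(logits=torch.zeros(batch_size, n_disc, n_choices))

action_cont = act_dist_cont.sample()                  # (batch_size, dim_cont)
action_disc = act_dist_disc.sample()                  # (batch_size, n_disc)

log_prob_cont = act_dist_cont.log_prob(action_cont)   # (batch_size,)
log_prob_disc = act_dist_disc.log_prob(action_disc)   # (batch_size, n_disc)
# Independence of the heads means the joint log-prob is the sum of the parts.
log_prob = log_prob_cont + log_prob_disc.sum(dim=1)   # (batch_size,)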
Example #2: optim_preprocess for a recurrent agent with dict observations and a hybrid continuous/discrete action space.
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']
        done = data['done']
        hidden_state = data['hidden_state']
        # Swap the first two dimensions of the stored hidden state for the
        # recurrent model.
        hidden_state = hidden_state.permute(1, 0, 2)

        act_dist_cont, act_dist_disc, val, out_hidden_state = self.get_act_val(
            {"ob": ob, "state": state}, hidden_state=hidden_state, done=done)
        # Split the stored action back into its continuous and discrete parts.
        action_cont = action[:, :, :self.dim_cont]
        action_discrete = action[:, :, self.dim_cont:]
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        # The discrete log-probs/entropies keep one entry per discrete action
        # dimension; sum over that trailing dimension, whether the tensors are
        # 2-D or 3-D (batch, time, dim).
        if log_prob_disc.dim() == 2:
            log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
            entropy = entropy_cont + torch.sum(entropy_disc, dim=1)
        else:
            log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=2)
            entropy = entropy_cont + torch.sum(entropy_disc, dim=2)

        processed_data = dict(
            val=val,
            old_val=old_val,
            ret=ret,
            log_prob=log_prob,
            old_log_prob=old_log_prob,
            adv=adv,
            entropy=entropy
        )
        return processed_data
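
The processed_data dict returned here (val, old_val, ret, log_prob, old_log_prob, adv, entropy) is presumably consumed by a PPO-style update. The sketch below shows a generic clipped-surrogate loss over such a dict; the clip range and loss weights are illustrative assumptions and this is not the original training code.

import torch

def ppo_loss(processed_data, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    adv = processed_data['adv']
    # Probability ratio between the current and the behavior policy.
    ratio = torch.exp(processed_data['log_prob'] - processed_data['old_log_prob'])
    # Clipped policy surrogate.
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    policy_loss = -torch.min(surr1, surr2).mean()
    # Simple value regression to the returns (some implementations also clip this term).
    value_loss = (processed_data['val'] - processed_data['ret']).pow(2).mean()
    entropy = processed_data['entropy'].mean()
    return policy_loss + vf_coef * value_loss - ent_coef * entropy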
Example #3: optim_preprocess for a recurrent agent with a single action distribution (None entries are left as-is when moving data to the device).
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            if val is not None:
                data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']
        done = data['done']
        hidden_state = data['hidden_state']
        hidden_state = hidden_state.permute(1, 0, 2)

        act_dist, val, out_hidden_state = self.get_act_val(
            ob, hidden_state=hidden_state, done=done)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data
Example #4: optim_preprocess for a feed-forward agent with dict observations and a single action distribution.
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist, val = self.get_act_val({"ob": ob, "state": state})
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data
Example #5: get_action for a feed-forward agent with a single action distribution.
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        t_ob = torch_float(ob, device=cfg.alg.device)
        act_dist, val = self.get_act_val(t_ob)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        return torch_to_np(action), action_info
Example #6: get_action for a recurrent agent; the observation gains a time dimension of length 1 and the outputs are squeezed back before returning.
    def get_action(self, ob, sample=True, hidden_state=None, *args, **kwargs):
        self.eval_mode()
        # Add a time dimension of length 1 for the recurrent model.
        t_ob = torch.from_numpy(ob).float().to(cfg.alg.device).unsqueeze(dim=1)
        act_dist, val, out_hidden_state = self.get_act_val(
            t_ob, hidden_state=hidden_state)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        # Squeeze the time dimension out again before converting to numpy.
        action_info = dict(
            log_prob=torch_to_np(log_prob.squeeze(1)),
            entropy=torch_to_np(entropy.squeeze(1)),
            val=torch_to_np(val.squeeze(1)),
        )
        return torch_to_np(action.squeeze(1)), action_info, out_hidden_state
Example #7: optim_preprocess for a feed-forward agent with dict observations and a hybrid continuous/discrete action space.
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist_cont, act_dist_disc, val = self.get_act_val({
            "ob": ob,
            "state": state
        })
        # Split the stored action into its continuous and discrete parts.
        action_cont = action[:, :self.dim_cont]
        action_discrete = action[:, self.dim_cont:]
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        # Joint log-prob/entropy: continuous term plus the sum over discrete dims.
        log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
        entropy = entropy_cont + torch.sum(entropy_disc, dim=1)

        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data
Example #8: get_action for a recurrent agent that accepts either dict or array observations and also records the incoming hidden state.
    def get_action(self, ob, sample=True, hidden_state=None, *args, **kwargs):
        self.eval_mode()

        # Dict observations are converted entry by entry; array observations get
        # an extra time dimension of length 1 for the recurrent model.
        if isinstance(ob, dict):
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch.from_numpy(ob).float().to(
                cfg.alg.device).unsqueeze(dim=1)

        act_dist, val, out_hidden_state = self.get_act_val(
            t_ob, hidden_state=hidden_state)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        in_hidden_state = torch_to_np(
            hidden_state) if hidden_state is not None else hidden_state
        action_info = dict(log_prob=torch_to_np(log_prob.squeeze(1)),
                           entropy=torch_to_np(entropy.squeeze(1)),
                           val=torch_to_np(val.squeeze(1)),
                           in_hidden_state=in_hidden_state)
        return torch_to_np(action.squeeze(1)), action_info, out_hidden_state
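
For the recurrent variants, the hidden state returned by get_action has to be fed back in on the next step. The sketch below shows one way such a rollout loop might look; the agent/env interfaces and the reset-to-None convention at episode boundaries are assumptions, not part of the original code.

def collect_rollout(agent, env, rollout_steps):
    # Thread the recurrent hidden state from one step to the next.
    hidden_state = None
    ob = env.reset()
    trajectory = []
    for _ in range(rollout_steps):
        action, action_info, hidden_state = agent.get_action(
            ob, sample=True, hidden_state=hidden_state)
        next_ob, reward, done, info = env.step(action)
        trajectory.append((ob, action, reward, done, action_info))
        ob = next_ob
        if done:
            ob = env.reset()
            hidden_state = None  # start a fresh recurrent state for the new episode
    return trajectory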