Example #1
0
 def calc_v(self, x, net=None, use_cache=True):
     '''Forward-pass to calculate the predicted state-value from critic_net.

     With a shared actor-critic net the value is the last output head; when
     use_cache is set, the value cached by calc_pdparam is reused instead of
     running a second forward pass.
     '''
     if not self.shared:
         critic = net if net is not None else self.critic_net
         return critic(x).view(-1)
     # shared net output: (policy, ..., value)
     if use_cache:
         # reuse value computed during calc_pdparam to prevent a double-pass
         return self.v_pred
     shared_net = net if net is not None else self.net
     return shared_net(x)[-1].view(-1)
Example #2
0
 def calc_q(self, state, action, net):
     '''Forward-pass to calculate the predicted state-action-value from q1_net.'''
     # a single continuous action arrives as a 1-D tensor; add a trailing dim
     # so its shape stays consistent with the (batch, action_dim) input the
     # net expects
     needs_expand = (not self.body.is_discrete) and action.dim() == 1
     if needs_expand:
         action = action.unsqueeze(dim=-1)
     return net(state, action).view(-1)
Example #3
0
 def calc_pdparam(self, x, net=None):
     '''Forward-pass the net to obtain the action-distribution parameters.

     The pdparam will be the logits for a discrete prob. dist., or the mean
     and std for a continuous prob. dist.
     '''
     chosen_net = self.net if net is None else net
     return chosen_net(x)
Example #4
0
 def calc_pdparam(self, x, net=None):
     '''Forward-pass the appropriate net to obtain the pdparam for action-policy sampling.

     The pdparam will be the logits for a discrete prob. dist., or the mean
     and std for a continuous prob. dist.
     '''
     if net is None:
         net = self.net
     return net(x)
Example #5
0
 def calc_pdparam(self, x, evaluate=True, net=None):
     '''Forward-pass to obtain the pdparam for action sampling.

     The pdparam will be the logits for a discrete prob. dist., or the mean
     and std for a continuous prob. dist.
     '''
     net = net if net is not None else self.net
     if not evaluate:
         # training-mode forward pass
         net.train()
         pdparam = net(x)
     else:
         pdparam = net.wrap_eval(x)
     logger.debug(f'pdparam: {pdparam}')
     return pdparam
Example #6
0
 def calc_pdparam(self, x, evaluate=True, net=None):
     '''Forward-pass the appropriate net and pick the outputs that form the pdparam for action-policy sampling.

     The pdparam will be the logits for a discrete prob. dist., or the mean
     and std for a continuous prob. dist.
     '''
     chosen = self.net if net is None else net
     if evaluate:
         # evaluation-mode forward pass
         pdparam = chosen.wrap_eval(x)
     else:
         chosen.train()
         pdparam = chosen(x)
     logger.debug(f'pdparam: {pdparam}')
     return pdparam
Example #7
0
 def calc_v(self, x, evaluate=True, net=None):
     '''Forward-pass to calculate the predicted state-value from critic.

     With a shared actor-critic net, the value is the last element of the
     net output; otherwise the dedicated critic net is used.
     '''
     net = self.net if net is None else net

     def forward(module):
         # wrap_eval for evaluation mode; plain train-mode forward otherwise
         if evaluate:
             return module.wrap_eval(x)
         module.train()
         return module(x)

     if self.shared:  # shared output: (policy, value)
         v = forward(net)[-1].squeeze_(dim=1)  # get value only
     else:
         v = forward(self.critic).squeeze_(dim=1)
     logger.debug(f'v: {v}')
     return v
Example #8
0
 def calc_pdparam(self, x, evaluate=True, net=None):
     '''Forward-pass to obtain the pdparam: logits for a discrete prob. dist., or the mean and std for a continuous prob. dist.

     For a shared architecture, the trailing critic (value) output is
     stripped so only the policy parameters are returned.
     '''
     net = self.net if net is None else net
     if evaluate:
         pdparam = net.wrap_eval(x)
     else:
         net.train()
         pdparam = net(x)
     if self.share_architecture:
         # MLPHeterogenousTails: drop the critic output at the tail.
         # Discrete → single logits head; continuous with len == 2 is only
         # (loc, scale) and (v) → single head; otherwise keep all but last.
         if self.body.is_discrete or len(pdparam) == 2:
             pdparam = pdparam[0]
         else:
             pdparam = pdparam[:-1]
     logger.debug(f'pdparam: {pdparam}')
     return pdparam
Example #9
0
 def calc_v(self, x, evaluate=True, net=None):
     '''Forward-pass to calculate the predicted state-value from critic.

     With a shared architecture the value is the last output of the main
     net; otherwise the separate critic net is used.
     '''
     net = self.net if net is None else net
     # pick the module whose forward pass produces the value output
     module = net if self.share_architecture else self.critic
     if evaluate:
         out = module.wrap_eval(x)
     else:
         module.train()
         out = module(x)
     if self.share_architecture:
         # MLPHeterogenousTails, get last
         v = out[-1].squeeze_(dim=1)
     else:
         v = out.squeeze_(dim=1)
     logger.debug(f'v: {v}')
     return v
Example #10
0
 def calc_q(self, state, action, net):
     '''Forward-pass to calculate the predicted state-action-value from q1_net.'''
     # flatten to a 1-D tensor of per-sample Q-values
     return net(state, action).view(-1)
Example #11
0
 def calc_q(self, state, action, net=None):
     '''Forward-pass to calculate the predicted state-action-value from q1_net.'''
     if net is None:
         net = self.q1_net
     # q-nets take the concatenated (state, action) vector as a single input
     state_action = torch.cat((state, action), dim=-1)
     return net(state_action).view(-1)