示例#1
0
 async def forward(self, iput):
     """
     Build the output distribution(s) with the p module (and optionally the q
     module), then sample an output and/or score a given one.

     iput:
         p = (p_iput, p_mask)
         q = (q_iput, q_mask) [optional]
         oput_size            [optional]
         true_oput            [optional]
         eval_oput            [optional]
     modes:
         no q + no true_oput + no eval_oput = rollout:                sample p(oput | iput, mask)
         no q + no true_oput +    eval_oput = evaluate:               sample p(oput | iput, mask),
                                                                          and compute error(oput, eval_oput)
         no q +    true_oput + no eval_oput = get_loss (act_arg):     compute -log p(oput | iput, mask)
            q +    true_oput + no eval_oput = get_loss (annotated):   compute -log p(oput | iput, mask)
                                                                             - log q(oput | iput, mask)
            q + no true_oput + no eval_oput = get_loss (unannotated): sample q(oput | iput, mask),
                                                                          and compute D[q(. | iput, mask)
                                                                                     || p(. | iput, mask)]
                                                                              and log q(oput | iput, mask)
     res:
         oput     true_oput when given, otherwise a fresh sample
         error    [in evaluate]
         loss     [in get_loss]; always includes an extra
                  ENTROPY_WEIGHT * self._neg_entropy(...) term for every
                  log-prob computed (p, and q when present)
         log_p    [in get_loss]
         log_q    [in get_loss]; all zeros when no q is given
     notes:
         - A sample is drawn only when true_oput is absent.
         - eval_oput is only consulted when no q is given.
     """
     has_true_oput = 'true_oput' in iput
     p_log_prob = self._get_log_prob(self.p_module, iput.p_iput,
                                     iput.get('p_mask'),
                                     iput.get('oput_size'))
     if 'q_iput' in iput:
         q_log_prob = self._get_log_prob(self.q_module, iput.q_iput,
                                         iput.get('q_mask'),
                                         iput.get('oput_size'))
         # Sample lazily: `iput.get('true_oput', self._sample(...))` would
         # evaluate the default eagerly and draw a sample that is discarded
         # whenever true_oput is supplied.
         oput = iput.true_oput if has_true_oput else self._sample(q_log_prob)
         log_p = self._log_prob(p_log_prob, oput)
         log_q = self._log_prob(q_log_prob, oput)
         if has_true_oput:
             # annotated: negative log-likelihood of the given output under p and q
             loss = -log_p - log_q
         else:
             # unannotated: divergence between q and p, as computed by self._dkl
             loss = self._dkl(q_log_prob, p_log_prob)
         res = DictTree(
             oput=oput,
             # TODO: have configurable entropy_weight
             loss=loss + ENTROPY_WEIGHT * (self._neg_entropy(p_log_prob) +
                                           self._neg_entropy(q_log_prob)),
             log_p=log_p,
             log_q=log_q,
         )
     else:
         # Same lazy sampling as above, from p instead of q.
         oput = iput.true_oput if has_true_oput else self._sample(p_log_prob)
         res = DictTree(oput=oput)
         if has_true_oput:
             log_p = self._log_prob(p_log_prob, oput)
             res.loss = -log_p + ENTROPY_WEIGHT * self._neg_entropy(p_log_prob)
             res.log_p = log_p
             # no q module was run: report a zero log_q alongside log_p
             res.log_q = torch.zeros_like(log_p)
         if 'eval_oput' in iput:
             res.error = self._error(oput, iput.eval_oput)
     return res
示例#2
0
 async def forward(self, iput):
     """Step the PHP call stack until an action is chosen or the root returns.

     Starting from a copy of iput.mem_in.stack, repeatedly invokes the PHP on
     top of the stack (self.phps[name]) and dispatches on the sub_name it
     returns:
       - None: the top PHP returns; its sub_arg becomes the caller's ret_val.
         If the stack is then empty, the agent terminates with act_name None.
       - a name in self.act_names: that action is emitted and stepping stops.
       - anything else: the named PHP is pushed as a new call frame.

     iput keys read:
       mem_in.stack, ret_name, ret_val, obs   always
       ctx                [optional] forwarded to every step; enables
                          collection of loss / log_p / log_q
       mem_out            [optional] with ctx, each recorded step's
                          sub_name / sub_arg is fed to the matching step;
                          the produced mem_out is asserted equal to it
       act_name, act_arg  forwarded to every step when ctx or act_name
                          is present

     Returns a DictTree with mem_out=(steps, stack) plus:
       loss, log_p, log_q   per-step lists concatenated, when ctx is given
       error                error of the last step, when act_name is given
                            without ctx
       act_name, act_arg    the chosen action, otherwise
     """
     frames = iput.mem_in.stack.copy()
     ret_name = iput.ret_name
     ret_val = iput.ret_val

     has_ctx = 'ctx' in iput
     has_mem_out = 'mem_out' in iput
     has_act_name = 'act_name' in iput
     forward_act = has_ctx or has_act_name

     trace = []
     losses, log_ps, log_qs = [], [], []
     replay_idx = 0
     while True:
         frame = frames[-1]
         php = self.phps[frame.name]
         step_iput = DictTree(
             is_root=len(frames) == 1,
             arg=self._get_value(frame.arg),
             cnt=frame.cnt,
             ret_name=ret_name,
             ret_val=self._get_value(ret_val),
             obs=self._get_value(iput.obs),
         )
         if has_ctx:
             step_iput.ctx = iput.ctx
             if has_mem_out:
                 # replay the recorded sub-call choice for this step
                 recorded = iput.mem_out.steps[replay_idx]
                 replay_idx += 1
                 step_iput.sub_name = recorded.sub_name
                 step_iput.sub_arg = recorded.sub_arg
         if forward_act:
             step_iput.act_name = iput.act_name
             step_iput.act_arg = iput.act_arg

         step_oput = await php(step_iput)
         sub_name, sub_arg = step_oput.sub_name, step_oput.sub_arg
         trace.append(DictTree(
             name=frame.name,
             arg=self._get_value(frame.arg, teacher=False),
             cnt=frame.cnt,
             ret_name=ret_name,
             ret_val=self._get_value(ret_val, teacher=False),
             sub_name=sub_name,
             sub_arg=sub_arg,
         ))
         if has_ctx:
             losses.extend(step_oput.loss)
             log_ps.extend(step_oput.log_p)
             log_qs.extend(step_oput.log_q)

         if sub_name is None:
             # the top PHP returns to its caller
             assert frame.cnt > 0
             frames.pop()
             if not frames:
                 # the root returned: the agent terminates
                 act_name, act_arg = None, sub_arg
                 break
             ret_name, ret_val = frame.name, sub_arg
             continue

         # both an action and a sub-call advance the current frame's cnt
         frames[-1] = DictTree(frame, cnt=frame.cnt + 1)
         if sub_name in self.act_names:
             act_name, act_arg = sub_name, sub_arg
             break

         # push the callee with no return value and a zero cnt
         ret_name = None
         ret_val = ret_val.new_empty(0)
         frames.append(DictTree(name=sub_name, arg=sub_arg, cnt=ret_val.new_zeros(1)))

     result = DictTree(
         mem_out=DictTree(steps=trace, stack=frames),
     )
     if has_mem_out:
         # the run must reproduce the recorded memory
         assert torch_utils.eq(iput.mem_out, result.mem_out)
     if has_ctx:
         result.loss = losses
         result.log_p = log_ps
         result.log_q = log_qs
     elif has_act_name:
         result.error = step_oput.error
     else:
         result.act_name = act_name
         result.act_arg = act_arg
     return result