Code example #1
0
 async def forward(self, iput):
     """
     Score or sample an output under the prior p (and, if given, the posterior q).

     iput:
         p = (p_iput, p_mask)
         q = (q_iput, q_mask) [optional]
         oput_size            [optional]
         true_oput            [optional]
         eval_oput            [optional]
     modes:
         no q + no true_oput + no eval_oput = rollout:                sample p(oput | iput, mask)
         no q + no true_oput +    eval_oput = evaluate:               sample p(oput | iput, mask),
                                                                          and compute error(oput, eval_oput)
         no q +    true_oput + no eval_oput = get_loss (act_arg):     compute -log p(oput | iput, mask)
            q +    true_oput + no eval_oput = get_loss (annotated):   compute -log p(oput | iput, mask)
                                                                             - log q(oput | iput, mask)
            q + no true_oput + no eval_oput = get_loss (unannotated): sample q(oput | iput, mask),
                                                                          and compute D[q(. | iput, mask)
                                                                                     || p(. | iput, mask)]
                                                                              and log q(oput | iput, mask)
     res:
         oput
         error    [in evaluate]
         loss     [in get_loss]
         log_p    [in get_loss]
         log_q    [in get_loss]
     """
     p_log_prob = self._get_log_prob(self.p_module, iput.p_iput,
                                     iput.get('p_mask'),
                                     iput.get('oput_size'))
     if 'q_iput' in iput:
         q_log_prob = self._get_log_prob(self.q_module, iput.q_iput,
                                         iput.get('q_mask'),
                                         iput.get('oput_size'))
         # NOTE: deliberately not iput.get('true_oput', self._sample(...)) —
         # the default argument would be evaluated eagerly, drawing a sample
         # (and advancing RNG state) even when the true output is supplied.
         if 'true_oput' in iput:
             oput = iput.true_oput
         else:
             oput = self._sample(q_log_prob)
         log_p = self._log_prob(p_log_prob, oput)
         log_q = self._log_prob(q_log_prob, oput)
         if 'true_oput' in iput:
             # Annotated: maximize likelihood under both prior and posterior.
             loss = -log_p - log_q
         else:
             # Unannotated: match the posterior to the prior via KL divergence.
             loss = self._dkl(q_log_prob, p_log_prob)
         res = DictTree(
             oput=oput,
             # TODO: have configurable entropy_weight
             loss=loss + ENTROPY_WEIGHT * (self._neg_entropy(p_log_prob) +
                                           self._neg_entropy(q_log_prob)),
             log_p=log_p,
             log_q=log_q,
         )
     else:
         # Same eager-default pitfall as above: only sample from p when the
         # true output is absent.
         if 'true_oput' in iput:
             oput = iput.true_oput
         else:
             oput = self._sample(p_log_prob)
         res = DictTree(oput=oput, )
         if 'true_oput' in iput:
             log_p = self._log_prob(p_log_prob, iput.true_oput)
             res.loss = -log_p + ENTROPY_WEIGHT * self._neg_entropy(
                 p_log_prob)
             res.log_p = log_p
             # No posterior in this mode; report zero log_q for a uniform API.
             res.log_q = torch.zeros_like(log_p)
         if 'eval_oput' in iput:
             res.error = self._error(oput, iput.eval_oput)
     return res
Code example #2
0
 async def forward(self, iput):
     """
     Run the PHP call stack until the root PHP terminates or some PHP
     selects a primitive action, returning the resulting memory trace.

     iput:
         mem_in.stack        current call stack; each frame has name, arg, cnt
         ret_name, ret_val   name/value returned by the previous sub-step
         obs                 current observation
         ctx                 [optional] context (loss-computation modes)
         mem_out             [optional] recorded steps to replay as supervision
         act_name, act_arg   [optional] annotated / evaluation-target action
     oput:
         mem_out             DictTree(steps, stack): steps taken, final stack
         loss, log_p, log_q  [when ctx is given]
         error               [when act_name is given without ctx]
         act_name, act_arg   [otherwise: rollout]
     """
     # Work on a copy so iput.mem_in.stack is left untouched.
     stack = iput.mem_in.stack.copy()
     ret_name = iput.ret_name
     ret_val = iput.ret_val
     steps = []
     loss = []
     log_p = []
     log_q = []
     step_idx = 0
     while True:
         # Dispatch one decision step of the PHP at the top of the stack.
         top = stack[-1]
         top_php = self.phps[top.name]
         step_iput = DictTree(
             is_root=(len(stack) == 1),
             arg=self._get_value(top.arg),
             cnt=top.cnt,
             ret_name=ret_name,
             ret_val=self._get_value(ret_val),
             obs=self._get_value(iput.obs),
         )
         if 'ctx' in iput:
             step_iput.ctx = iput.ctx
             if 'mem_out' in iput:
                 # Annotated: feed this step's recorded sub-decision.
                 step = iput.mem_out.steps[step_idx]
                 step_iput.sub_name = step.sub_name
                 step_iput.sub_arg = step.sub_arg
                 step_idx += 1
             step_iput.act_name = iput.act_name
             step_iput.act_arg = iput.act_arg
         elif 'act_name' in iput:
             # Evaluation mode: pass the target action without a context.
             step_iput.act_name = iput.act_name
             step_iput.act_arg = iput.act_arg
         step_oput = await top_php(step_iput)
         # Record the step with teacher-free values for the trace.
         steps.append(DictTree(
             name=top.name,
             arg=self._get_value(top.arg, teacher=False),
             cnt=top.cnt,
             ret_name=ret_name,
             ret_val=self._get_value(ret_val, teacher=False),
             sub_name=step_oput.sub_name,
             sub_arg=step_oput.sub_arg,
         ))
         if 'ctx' in iput:
             # Accumulate per-step training terms.
             loss.extend(step_oput.loss)
             log_p.extend(step_oput.log_p)
             log_q.extend(step_oput.log_q)
         if step_oput.sub_name is None:
             # terminate php: pop the frame; sub_arg is its return value
             assert top.cnt > 0
             stack.pop()
             if stack:
                 ret_name = top.name
                 ret_val = step_oput.sub_arg
             else:
                 # terminate agent (the root PHP returned)
                 act_name = None
                 act_arg = step_oput.sub_arg
                 break
         elif step_oput.sub_name in self.act_names:
             # take action: bump the frame's step count and emit the action
             stack[-1] = DictTree(top, cnt=top.cnt + 1)
             act_name = step_oput.sub_name
             act_arg = step_oput.sub_arg
             break
         else:
             # call php: push a fresh frame (cnt 0, empty return) for the sub-PHP
             stack[-1] = DictTree(top, cnt=top.cnt + 1)
             ret_name = None
             ret_val = ret_val.new_empty(0)
             stack.append(DictTree(name=step_oput.sub_name, arg=step_oput.sub_arg, cnt=ret_val.new_zeros(1)))
     oput = DictTree(
         mem_out=DictTree(steps=steps, stack=stack),
     )
     if 'mem_out' in iput:
         # Replaying the recorded steps must reproduce the recorded memory.
         assert torch_utils.eq(iput.mem_out, oput.mem_out)
     if 'ctx' in iput:
         oput.loss = loss
         oput.log_p = log_p
         oput.log_q = log_q
     elif 'act_name' in iput:
         # Evaluation mode: report the error of the final (action) step.
         oput.error = step_oput.error
     else:
         oput.act_name = act_name
         oput.act_arg = act_arg
     return oput
Code example #3
0
File: php.py  Project: royf/hvil
 async def forward(self, iput):
     """
     One decision step of this PHP: select a sub name (sub-PHP, action, or
     None to terminate) via self.sub_module, and its argument via
     self.arg_module.

     iput:
         is_root
         arg
         cnt
         ret = (ret_name, ret_val)
         obs
         ctx                       [optional]
         sub = (sub_name, sub_arg) [optional]
         act = (act_name, act_arg) [optional]
     modes:
         no ctx + no sub + no act = rollout:                sample p(sub | arg, cnt, ret, obs)
         no ctx + no sub +    act = evaluate:               sample p(sub | arg, cnt, ret, obs),
                                                                and compute error(sub, iput.act)
            ctx +    sub +    act = get_loss (annotated):   compute -log p(sub | arg, cnt, ret, obs)
                                                                   - log q(sub | arg, cnt, ret, ctx, act)
            ctx + no sub +    act = get_loss (unannotated): sample q(sub | arg, cnt, ret, ctx, act),
                                                                and compute D[q(. | arg, cnt, ret, ctx, act)
                                                                           || p(. | arg, cnt, ret, obs)]
                                                                    and log q(sub | arg, cnt, ret, ctx, act)
     oput:
         sub = (sub_name, sub_arg)
         error                     [in evaluate]
         loss                      [in get_loss]
         log_q                     [in get_loss]
     """
     # All tensor inputs must be detached; gradients flow only through the
     # sub/arg modules below.
     assert not iput.arg.requires_grad
     assert not iput.cnt.requires_grad
     assert not iput.ret_val.requires_grad
     assert not iput.obs.requires_grad
     has_sub = 'sub_name' in iput
     if has_sub:
         assert not iput.sub_arg.requires_grad
     has_act = 'act_name' in iput
     if has_act:
         assert not iput.act_arg.requires_grad
     # Pad the return value to a fixed width and one-hot the returning
     # sub's name so they can be fed to the modules.
     iput.ret_val = torch_utils.pad(iput.ret_val, self.ret_in_size, 0)
     iput.ret1h = torch_utils.one_hot(iput.ret_name,
                                      self.sub_names,
                                      dtype=iput.arg.dtype,
                                      device=iput.arg.device)
     if 'ctx' in iput:
         # get_loss modes: score/sample the sub name under both the prior p
         # and the posterior q (batch dimension of 1 added via [None]).
         sub_iput = DictTree(
             p_iput=self._get_iput(iput, is_posterior=False,
                                   is_arg=False)[None],
             p_mask=self._get_mask(iput, is_posterior=False)[None],
             q_iput=self._get_iput(iput, is_posterior=True,
                                   is_arg=False)[None],
             q_mask=self._get_mask(iput, is_posterior=True)[None],
         )
         if has_sub:
             # Annotated: supervise with the recorded sub name.
             sub1h = torch_utils.one_hot(iput.sub_name,
                                         self.sub_names,
                                         dtype=iput.arg.dtype,
                                         device=iput.arg.device)
             sub_iput.true_oput = sub1h[None]
             sub_oput = await self.sub_module(sub_iput)
             sub_name = iput.sub_name
         else:
             # Unannotated: sample the sub name from the posterior.
             sub_oput = await self.sub_module(sub_iput)
             [sub1h] = sub_oput.oput
             sub_name = torch_utils.lookup(sub1h, self.sub_names)
         [sub_loss] = sub_oput.loss
         [sub_log_p] = sub_oput.log_p
         [sub_log_q] = sub_oput.log_q
         if self.arg_module is None:
             # This PHP produces no argument; emit zero-valued loss terms.
             assert self.arg_out_size == 0
             if has_sub:
                 sub_arg = iput.sub_arg
             else:
                 sub_arg = iput.arg.new_empty(0)
             arg_loss = iput.arg.new_zeros(())
             arg_log_p = iput.arg.new_zeros(())
             arg_log_q = iput.arg.new_zeros(())
         else:
             # Terminating (sub_name None) emits a return value; otherwise
             # the argument sized for the chosen sub.
             sub_arg_size = self.ret_out_size if sub_name is None else self.subs[
                 sub_name].arg_in_size
             sub_is_arg = self._sub_is_arg(iput
                                           | DictTree(sub_name=sub_name))
             arg_iput = DictTree(
                 p_iput=self._get_iput(iput | DictTree(sub1h=sub1h),
                                       is_posterior=False,
                                       is_arg=True)[None],
                 oput_size=iput.arg.new_full((),
                                             sub_arg_size,
                                             dtype=torch.int64)[None],
             )
             if sub_is_arg:
                 # For these subs the sub argument equals the action
                 # argument (checked below), so supervise with act_arg.
                 if has_sub:
                     # noinspection PyUnresolvedReferences
                     assert (iput.sub_arg == iput.act_arg).all()
                 # TODO: use iput.act_arg as auxiliary task
                 arg_iput.true_oput = torch_utils.pad(
                     iput.act_arg, self.arg_out_size, 0)[None]
             else:
                 # Otherwise score/sample the argument under the posterior,
                 # supervising with the recorded sub_arg when annotated.
                 arg_iput.q_iput = self._get_iput(iput
                                                  | DictTree(sub1h=sub1h),
                                                  is_posterior=True,
                                                  is_arg=True)[None]
                 if has_sub:
                     arg_iput.true_oput = torch_utils.pad(
                         iput.sub_arg, self.arg_out_size, 0)[None]
             arg_oput = await self.arg_module(arg_iput)
             if has_sub:
                 sub_arg = iput.sub_arg
             else:
                 # Detach the sampled argument and strip the padding.
                 [sub_arg] = arg_oput.oput.detach()[:, :sub_arg_size]
             [arg_loss] = arg_oput.loss
             [arg_log_p] = arg_oput.log_p
             [arg_log_q] = arg_oput.log_q
         oput = DictTree(
             sub_name=sub_name,
             sub_arg=sub_arg,
             loss=[sub_loss, arg_loss],
             log_p=[sub_log_p, arg_log_p],
             log_q=[sub_log_q, arg_log_q],
         )
     else:
         # Rollout / evaluate modes: sample the sub name from the prior only.
         sub_iput = DictTree(
             p_iput=self._get_iput(iput, is_posterior=False,
                                   is_arg=False)[None],
             p_mask=self._get_mask(iput, is_posterior=False)[None],
         )
         if has_act and iput.act_name in self.sub_names:
             # Evaluate: compare the sampled sub name to the target action.
             sub_iput.eval_oput = torch_utils.one_hot(
                 iput.act_name,
                 self.sub_names,
                 dtype=iput.arg.dtype,
                 device=iput.arg.device)[None]
         sub_oput = await self.sub_module(sub_iput)
         [sub1h] = sub_oput.oput
         sub_name = torch_utils.lookup(sub1h, self.sub_names)
         if has_act and iput.act_name in self.sub_names:
             [sub_error] = sub_oput.error.detach()
         else:
             # Target action is not among this PHP's subs: count as an error.
             sub_error = True
         if self.arg_module is None:
             # No argument to get wrong.
             assert self.arg_out_size == 0
             sub_arg = iput.arg.new_empty(0)
             arg_error = False
         else:
             sub_arg_size = self.ret_out_size if sub_name is None else self.subs[
                 sub_name].arg_in_size
             arg_iput = DictTree(
                 p_iput=self._get_iput(iput | DictTree(sub1h=sub1h),
                                       is_posterior=False,
                                       is_arg=True)[None],
                 oput_size=iput.arg.new_full((),
                                             sub_arg_size,
                                             dtype=torch.int64)[None],
             )
             # Only evaluate the argument when the name already matched.
             if has_act and not sub_error:
                 arg_iput.eval_oput = torch_utils.pad(
                     iput.act_arg, self.arg_out_size, 0)[None]
             arg_oput = await self.arg_module(arg_iput)
             [sub_arg] = arg_oput.oput.detach()[:, :sub_arg_size]
             if has_act and not sub_error:
                 [arg_error] = arg_oput.error.detach()
             else:
                 arg_error = True
         oput = DictTree(
             sub_name=sub_name,
             sub_arg=sub_arg,
         )
         if has_act:
             # The step is an error if either the name or the argument is.
             oput.error = bool(sub_error) or bool(arg_error)
     return oput