async def forward(self, iput):
    """Score or sample an output under the prior p (and optionally posterior q).

    iput:
        p = (p_iput, p_mask)
        q = (q_iput, q_mask) [optional]
        oput_size [optional]
        true_oput [optional]
        eval_oput [optional]

    modes:
        no q + no true_oput + no eval_oput = rollout:
            sample p(oput | iput, mask)
        no q + no true_oput + eval_oput = evaluate:
            sample p(oput | iput, mask), and compute error(oput, eval_oput)
        no q + true_oput + no eval_oput = get_loss (act_arg):
            compute -log p(oput | iput, mask)
        q + true_oput + no eval_oput = get_loss (annotated):
            compute -log p(oput | iput, mask) - log q(oput | iput, mask)
        q + no true_oput + no eval_oput = get_loss (unannotated):
            sample q(oput | iput, mask), and compute
            D[q(. | iput, mask) || p(. | iput, mask)] and log q(oput | iput, mask)

    res:
        oput
        error [in evaluate]
        loss [in get_loss]
        log_p [in get_loss]
        log_q [in get_loss]
    """
    prior_lp = self._get_log_prob(
        self.p_module, iput.p_iput, iput.get('p_mask'), iput.get('oput_size'))
    has_posterior = 'q_iput' in iput
    if not has_posterior:
        # Prior-only path: rollout / evaluate / get_loss on a given output.
        oput = iput.get('true_oput', self._sample(prior_lp))
        res = DictTree(oput=oput)
        if 'true_oput' in iput:
            log_p = self._log_prob(prior_lp, iput.true_oput)
            # TODO: have configurable entropy_weight
            res.loss = -log_p + ENTROPY_WEIGHT * self._neg_entropy(prior_lp)
            res.log_p = log_p
            # No posterior in this mode; report a zero log_q of matching shape.
            res.log_q = torch.zeros_like(log_p)
    else:
        post_lp = self._get_log_prob(
            self.q_module, iput.q_iput, iput.get('q_mask'), iput.get('oput_size'))
        # Annotated: score the given output; unannotated: sample from q.
        oput = iput.get('true_oput', self._sample(post_lp))
        log_p = self._log_prob(prior_lp, oput)
        log_q = self._log_prob(post_lp, oput)
        if 'true_oput' in iput:
            base_loss = -log_p - log_q
        else:
            base_loss = self._dkl(post_lp, prior_lp)
        # TODO: have configurable entropy_weight
        entropy_term = ENTROPY_WEIGHT * (
            self._neg_entropy(prior_lp) + self._neg_entropy(post_lp))
        res = DictTree(
            oput=oput,
            loss=base_loss + entropy_term,
            log_p=log_p,
            log_q=log_q,
        )
    if 'eval_oput' in iput:
        res.error = self._error(oput, iput.eval_oput)
    return res
async def forward(self, iput):
    """Run the PHP call stack forward until an action is emitted or the root PHP terminates.

    iput:
        mem_in.stack: current call stack; each frame has name, arg, cnt
        ret_name, ret_val: return info from the previously terminated sub-PHP
        obs: current observation
        ctx [optional]: presence selects get_loss mode
        mem_out [optional]: annotated step trace to teacher-force (used with ctx)
        act_name, act_arg [optional]: reference action (used with ctx, or alone for evaluate)

    modes (selected by the dispatch at the end of this method):
        ctx present            = get_loss: oput has loss / log_p / log_q lists
        act_name only          = evaluate: oput has error (from the last step)
        neither                = rollout: oput has act_name / act_arg

    oput:
        mem_out = DictTree(steps, stack) in every mode
    """
    # Copy so the caller's mem_in.stack is not mutated by pop/append below.
    stack = iput.mem_in.stack.copy()
    ret_name = iput.ret_name
    ret_val = iput.ret_val
    steps = []
    loss = []
    log_p = []
    log_q = []
    step_idx = 0
    while True:
        # Step the PHP at the top of the stack.
        top = stack[-1]
        top_php = self.phps[top.name]
        step_iput = DictTree(
            is_root=(len(stack) == 1),
            arg=self._get_value(top.arg),
            cnt=top.cnt,
            ret_name=ret_name,
            ret_val=self._get_value(ret_val),
            obs=self._get_value(iput.obs),
        )
        if 'ctx' in iput:
            # get_loss mode: pass context, and (if annotated) teacher-force the
            # recorded sub-call for this step from the given mem_out trace.
            step_iput.ctx = iput.ctx
            if 'mem_out' in iput:
                step = iput.mem_out.steps[step_idx]
                step_iput.sub_name = step.sub_name
                step_iput.sub_arg = step.sub_arg
                step_idx += 1
            step_iput.act_name = iput.act_name
            step_iput.act_arg = iput.act_arg
        elif 'act_name' in iput:
            # evaluate mode: pass only the reference action.
            step_iput.act_name = iput.act_name
            step_iput.act_arg = iput.act_arg
        step_oput = await top_php(step_iput)
        steps.append(DictTree(
            name=top.name,
            arg=self._get_value(top.arg, teacher=False),
            cnt=top.cnt,
            ret_name=ret_name,
            ret_val=self._get_value(ret_val, teacher=False),
            sub_name=step_oput.sub_name,
            sub_arg=step_oput.sub_arg,
        ))
        if 'ctx' in iput:
            # Accumulate per-step loss terms across the whole trajectory of steps.
            loss.extend(step_oput.loss)
            log_p.extend(step_oput.log_p)
            log_q.extend(step_oput.log_q)
        if step_oput.sub_name is None:
            # terminate php: pop the frame; sub_arg becomes the return value.
            assert top.cnt > 0
            stack.pop()
            if stack:
                ret_name = top.name
                ret_val = step_oput.sub_arg
            else:
                # terminate agent: root PHP finished, no action is emitted.
                act_name = None
                act_arg = step_oput.sub_arg
                break
        elif step_oput.sub_name in self.act_names:
            # take action: advance the frame's step counter and emit the action.
            stack[-1] = DictTree(top, cnt=top.cnt + 1)
            act_name = step_oput.sub_name
            act_arg = step_oput.sub_arg
            break
        else:
            # call php: advance the caller's counter, push a fresh callee frame;
            # the callee starts with no pending return value.
            stack[-1] = DictTree(top, cnt=top.cnt + 1)
            ret_name = None
            ret_val = ret_val.new_empty(0)
            stack.append(DictTree(
                name=step_oput.sub_name,
                arg=step_oput.sub_arg,
                cnt=ret_val.new_zeros(1),
            ))
    oput = DictTree(
        mem_out=DictTree(steps=steps, stack=stack),
    )
    if 'mem_out' in iput:
        # Teacher-forced run must reproduce the annotated trace exactly.
        assert torch_utils.eq(iput.mem_out, oput.mem_out)
    if 'ctx' in iput:
        oput.loss = loss
        oput.log_p = log_p
        oput.log_q = log_q
    elif 'act_name' in iput:
        # Error of the final (action-emitting) step vs. the reference action.
        oput.error = step_oput.error
    else:
        oput.act_name = act_name
        oput.act_arg = act_arg
    return oput
async def forward(self, iput):
    """
    iput:
        is_root
        arg
        cnt
        ret = (ret_name, ret_val)
        obs
        ctx [optional]
        sub = (sub_name, sub_arg) [optional]
        act = (act_name, act_arg) [optional]

    modes:
        no ctx + no sub + no act = rollout:
            sample p(sub | arg, cnt, ret, obs)
        no ctx + no sub + act = evaluate:
            sample p(sub | arg, cnt, ret, obs), and compute error(sub, iput.act)
        ctx + sub + act = get_loss (annotated):
            compute -log p(sub | arg, cnt, ret, obs) - log q(sub | arg, cnt, ret, ctx, act)
        ctx + no sub + act = get_loss (unannotated):
            sample q(sub | arg, cnt, ret, ctx, act), and compute
            D[q(. | arg, cnt, ret, ctx, act) || p(. | arg, cnt, ret, obs)]
            and log q(sub | arg, cnt, ret, ctx, act)

    oput:
        sub = (sub_name, sub_arg)
        error [in evaluate]
        loss [in get_loss]
        log_q [in get_loss]
    """
    # Inputs must be detached from any autograd graph.
    assert not iput.arg.requires_grad
    assert not iput.cnt.requires_grad
    assert not iput.ret_val.requires_grad
    assert not iput.obs.requires_grad
    has_sub = 'sub_name' in iput
    if has_sub:
        assert not iput.sub_arg.requires_grad
    has_act = 'act_name' in iput
    if has_act:
        assert not iput.act_arg.requires_grad
    # Normalize the return value to a fixed width and one-hot-encode which
    # sub returned.
    iput.ret_val = torch_utils.pad(iput.ret_val, self.ret_in_size, 0)
    iput.ret1h = torch_utils.one_hot(iput.ret_name, self.sub_names,
                                     dtype=iput.arg.dtype, device=iput.arg.device)
    if 'ctx' in iput:
        # --- get_loss mode: run both prior p and posterior q branches. ---
        # [None] adds a batch dimension of size 1 throughout.
        sub_iput = DictTree(
            p_iput=self._get_iput(iput, is_posterior=False, is_arg=False)[None],
            p_mask=self._get_mask(iput, is_posterior=False)[None],
            q_iput=self._get_iput(iput, is_posterior=True, is_arg=False)[None],
            q_mask=self._get_mask(iput, is_posterior=True)[None],
        )
        if has_sub:
            # Annotated: score the recorded sub choice.
            sub1h = torch_utils.one_hot(iput.sub_name, self.sub_names,
                                        dtype=iput.arg.dtype, device=iput.arg.device)
            sub_iput.true_oput = sub1h[None]
            sub_oput = await self.sub_module(sub_iput)
            sub_name = iput.sub_name
        else:
            # Unannotated: sample the sub choice from the posterior.
            sub_oput = await self.sub_module(sub_iput)
            [sub1h] = sub_oput.oput
            sub_name = torch_utils.lookup(sub1h, self.sub_names)
        # Unpack the single batch element.
        [sub_loss] = sub_oput.loss
        [sub_log_p] = sub_oput.log_p
        [sub_log_q] = sub_oput.log_q
        if self.arg_module is None:
            # No argument head: the sub argument carries no loss terms.
            assert self.arg_out_size == 0
            if has_sub:
                sub_arg = iput.sub_arg
            else:
                sub_arg = iput.arg.new_empty(0)
            arg_loss = iput.arg.new_zeros(())
            arg_log_p = iput.arg.new_zeros(())
            arg_log_q = iput.arg.new_zeros(())
        else:
            # Terminating (sub_name is None) means the output is a return value.
            sub_arg_size = self.ret_out_size if sub_name is None else self.subs[
                sub_name].arg_in_size
            sub_is_arg = self._sub_is_arg(iput | DictTree(sub_name=sub_name))
            arg_iput = DictTree(
                p_iput=self._get_iput(iput | DictTree(sub1h=sub1h),
                                      is_posterior=False, is_arg=True)[None],
                oput_size=iput.arg.new_full((), sub_arg_size, dtype=torch.int64)[None],
            )
            if sub_is_arg:
                # The chosen sub's argument is the action argument itself.
                if has_sub:
                    # noinspection PyUnresolvedReferences
                    assert (iput.sub_arg == iput.act_arg).all()
                # TODO: use iput.act_arg as auxiliary task
                arg_iput.true_oput = torch_utils.pad(
                    iput.act_arg, self.arg_out_size, 0)[None]
            else:
                arg_iput.q_iput = self._get_iput(iput | DictTree(sub1h=sub1h),
                                                 is_posterior=True, is_arg=True)[None]
                if has_sub:
                    arg_iput.true_oput = torch_utils.pad(
                        iput.sub_arg, self.arg_out_size, 0)[None]
            arg_oput = await self.arg_module(arg_iput)
            if has_sub:
                sub_arg = iput.sub_arg
            else:
                # Trim the padded output back to the callee's expected size.
                [sub_arg] = arg_oput.oput.detach()[:, :sub_arg_size]
            [arg_loss] = arg_oput.loss
            [arg_log_p] = arg_oput.log_p
            [arg_log_q] = arg_oput.log_q
        oput = DictTree(
            sub_name=sub_name,
            sub_arg=sub_arg,
            loss=[sub_loss, arg_loss],
            log_p=[sub_log_p, arg_log_p],
            log_q=[sub_log_q, arg_log_q],
        )
    else:
        # --- rollout / evaluate mode: prior p only. ---
        sub_iput = DictTree(
            p_iput=self._get_iput(iput, is_posterior=False, is_arg=False)[None],
            p_mask=self._get_mask(iput, is_posterior=False)[None],
        )
        if has_act and iput.act_name in self.sub_names:
            # Evaluate: compare the sampled sub against the reference action.
            sub_iput.eval_oput = torch_utils.one_hot(
                iput.act_name, self.sub_names,
                dtype=iput.arg.dtype, device=iput.arg.device)[None]
            sub_oput = await self.sub_module(sub_iput)
            [sub1h] = sub_oput.oput
            sub_name = torch_utils.lookup(sub1h, self.sub_names)
        if has_act and iput.act_name in self.sub_names:
            [sub_error] = sub_oput.error.detach()
        else:
            # Reference action is not a known sub: count the sub choice as wrong.
            sub_error = True
        if self.arg_module is None:
            assert self.arg_out_size == 0
            sub_arg = iput.arg.new_empty(0)
            arg_error = False
        else:
            sub_arg_size = self.ret_out_size if sub_name is None else self.subs[
                sub_name].arg_in_size
            arg_iput = DictTree(
                p_iput=self._get_iput(iput | DictTree(sub1h=sub1h),
                                      is_posterior=False, is_arg=True)[None],
                oput_size=iput.arg.new_full((), sub_arg_size, dtype=torch.int64)[None],
            )
            if has_act and not sub_error:
                # Only score the argument when the sub choice was correct.
                arg_iput.eval_oput = torch_utils.pad(
                    iput.act_arg, self.arg_out_size, 0)[None]
            arg_oput = await self.arg_module(arg_iput)
            [sub_arg] = arg_oput.oput.detach()[:, :sub_arg_size]
            if has_act and not sub_error:
                [arg_error] = arg_oput.error.detach()
            else:
                arg_error = True
        oput = DictTree(
            sub_name=sub_name,
            sub_arg=sub_arg,
        )
        if has_act:
            # The step is in error if either the sub choice or its argument is.
            oput.error = bool(sub_error) or bool(arg_error)
    return oput