async def __call__(self, iput):
    """Cycle through sub-tasks by step count: None, 'ptr1right', 'ptr2right'.

    NOTE(review): iput.cnt appears to be an indexable counter (cnt[0]) —
    confirm against the caller.
    """
    cycle = (None, 'ptr1right', 'ptr2right')
    chosen = cycle[int((iput.cnt[0] + 1) % 3)]
    return DictTree(
        sub_name=chosen,
        sub_arg=torch.empty(0),
    )
async def __call__(self, iput):
    """Keep issuing 'lshift'; stop once a sub-task has returned and both
    obs flags (indices 10 and 21) read 1.

    NOTE(review): obs[10]/obs[21] look like pointer-boundary flags — verify
    against the observation layout.
    """
    # De Morgan of the original "continue" condition:
    # continue iff ret_name is None or either flag != 1.
    finished = (
        iput.ret_name is not None
        and iput.obs[10] == 1
        and iput.obs[21] == 1
    )
    sub_name = None if finished else 'lshift'
    return DictTree(
        sub_name=sub_name,
        sub_arg=torch.empty(0),
    )
async def __call__(self, iput):
    """Start with 'ptr2right', run 'bstep' until both obs flags (10, 21)
    read 1 and the last return was not 'ptr2right', then terminate.
    """
    if iput.ret_name is None:
        # First invocation: position pointer 2 first.
        sub_name = 'ptr2right'
    elif (iput.obs[10] != 1) or (iput.obs[21] != 1) or (
            iput.ret_name == 'ptr2right'):
        sub_name = 'bstep'
    else:
        # Remaining case: both flags == 1 and ret_name != 'ptr2right',
        # which is exactly the original third branch — done.
        sub_name = None
    return DictTree(
        sub_name=sub_name,
        sub_arg=torch.empty(0),
    )
async def __call__(self, iput):
    """On the first step dispatch sub-task 'A1'; afterwards signal done.

    A stochastic choice between A0/A1 (np.random.choice with p=[0.01, 0.99])
    was disabled in the original; the pick is deterministically 'A1'.
    """
    if iput.cnt != 0:
        sub_name = None
    else:
        sub_name = 'A1'
    return DictTree(
        sub_name=sub_name,
        sub_arg=torch.empty(0),
    )
async def __call__(self, iput):
    """Dispatch to 'P0' or 'P1' by the parity of the observed digit, then stop.

    NOTE(review): assumes iput.obs is a one-hot digit encoding so argmax is
    the digit value — confirm against the observation layout.
    """
    digit = iput.obs.argmax()
    if iput.cnt != 0:
        sub_name = None
    else:
        sub_name = 'P0' if digit % 2 == 0 else 'P1'
    return DictTree(
        sub_name=sub_name,
        sub_arg=torch.empty(0),
    )
async def __call__(self, iput):
    """Move env.pointer2 one step left (clamped at 1) and rebuild env.last_obs.

    last_obs layout: [one-hot digit under ptr1, ptr1-at-edge flag,
                      one-hot digit under ptr2, ptr2-at-edge flag].
    Returns an empty tensor (no payload).
    """
    env = iput.env
    env.pointer2 = max(env.pointer2 - 1, 1)
    # A pointer is "at an edge" when it sits on the first or last slot.
    ptr1_at_edge = (env.pointer1 == 1) or (env.pointer1 == env.length)
    ptr2_at_edge = (env.pointer2 == 1) or (env.pointer2 == env.length)
    env.last_obs = [
        torch_utils.one_hot(env.list_to_sort[env.pointer1 - 1], 10),
        torch.tensor([float(ptr1_at_edge)]),
        torch_utils.one_hot(env.list_to_sort[env.pointer2 - 1], 10),
        torch.tensor([float(ptr2_at_edge)]),
    ]
    return torch.empty(0)
async def __call__(self, iput):
    """One bubble step: 'swap' if the left digit exceeds the right, otherwise
    'rshift'; terminate after an 'rshift' returns.

    Guards reordered (terminal case first) — equivalent, since both original
    non-terminal branches required ret_name != 'rshift'.
    """
    left = int(iput.obs[:10].argmax())
    right = int(iput.obs[11:21].argmax())
    if iput.ret_name == 'rshift':
        sub_name = None
    elif left > right:
        sub_name = 'swap'
    else:
        sub_name = 'rshift'
    return DictTree(
        sub_name=sub_name,
        sub_arg=torch.empty(0),
    )
async def __call__(self, iput):
    """Alternate 'bubble' and 'resetptr' until 2*arg - 2 sub-steps have run.

    Guards reordered (terminal count checked first) — equivalent, since both
    original non-terminal branches required cnt[0] != 2*arg - 2.
    NOTE(review): like the original, an unexpected ret_name before the final
    count leaves sub_name unbound (NameError) — confirm callers never do that.
    """
    final_cnt = 2 * iput.arg - 2
    if iput.cnt[0] == final_cnt:
        sub_name = None
    elif iput.ret_name == 'bubble':
        sub_name = 'resetptr'
    elif iput.ret_name in (None, 'resetptr'):
        sub_name = 'bubble'
    return DictTree(
        sub_name=sub_name,
        sub_arg=torch.empty(0),
    )
def rollout(self, env):
    """Run one full episode in `env` under the current policy and record it.

    Puts the model in eval mode and disables gradients for the collection
    loop, then restores train mode before returning.

    Args:
        env: environment exposing reset()/observe() and an async step().

    Returns:
        DictTree trace with `metadata.init_arg` and `data.steps`, one entry
        per step (memory in/out, return name/value, observation, action).
    """
    self.eval()  # set to evaluation mode
    # Hoisted: resolve the event loop once instead of on every iteration.
    loop = asyncio.get_event_loop()
    with torch.no_grad():
        init_arg = env.reset()
        memory = self.reset(init_arg)
        trace = DictTree(
            metadata=DictTree(init_arg=self._get_value(init_arg, teacher=False)),
            data=DictTree(steps=[]),
        )
        ret_name = None
        ret_val = torch.empty(0)
        done = False
        while not done:
            obs = env.observe()
            iput = DictTree(
                mem_in=memory,
                ret_name=ret_name,
                ret_val=ret_val,
                obs=obs,
            )
            # The policy itself is an async callable.
            oput = loop.run_until_complete(self(iput))
            trace.data.steps.append(DictTree(
                mem_in=memory,
                ret_name=ret_name,
                ret_val=self._get_value(ret_val, teacher=False),
                obs=self._get_value(obs, teacher=False),
                mem_out=oput.mem_out,
                act_name=oput.act_name,
                act_arg=oput.act_arg,
            ))
            if oput.act_name is None:
                # A None action name terminates the episode.
                done = True
            else:
                memory = oput.mem_out
                ret_name = oput.act_name
                ret_val = loop.run_until_complete(
                    env.step(oput.act_name, oput.act_arg))
    self._postprocess(trace)
    self.train()  # set to training mode
    return trace
async def __call__(self, iput): return torch.empty(0)
def reset(self):
    """Clear the cached digit and return an empty initial value tree."""
    self.digit = None
    initial = DictTree(value=torch.empty(0))
    return initial