def collectData(agent):
    """Fill ``agent``'s memory by stepping it through the fruitbot environment.

    Runs ``agent.memory.size`` environment steps in total, starting a new
    episode (``env.reset()`` plus zeroed state tensors) whenever the previous
    one reports ``done``. Every transition is handed to ``agent.remember``.

    Printing is disabled via ``disablePrint()`` for the duration of the
    rollout and re-enabled afterwards; the environment is always closed, even
    when a step raises.

    Args:
        agent: Object exposing ``memory.size``, ``memory.memory``,
            ``choose(obs, hn, cn)`` and ``remember(...)``.

    Returns:
        ``agent.memory.memory`` as it stands after the rollout.
    """
    print('Start', agent.memory.size)
    disablePrint()
    try:
        steps_left = agent.memory.size
        env = Environment(render=False).fruitbot
        try:
            while steps_left > 0:
                # New episode: fresh observation and zero-initialised state.
                # presumably hn/cn are LSTM hidden/cell states - TODO confirm
                obs = clean(env.reset())
                hn = torch.zeros(2, 1, hidden_size, device=device)
                cn = torch.zeros(2, 1, hidden_size, device=device)
                while steps_left > 0:
                    steps_left -= 1
                    # hn/cn are fed back into choose() un-detached; only the
                    # copies stored in memory are detached.
                    act, obs_old, h0, c0, hn, cn = agent.choose(obs, hn, cn)
                    obs, rew, done, _ = env.step(act)
                    # remember() returns the observation used for the next step.
                    obs = agent.remember(
                        obs_old.detach(),
                        act,
                        clean(obs).detach(),
                        rew,
                        h0.detach(),
                        c0.detach(),
                        hn.detach(),
                        cn.detach(),
                        int(not done),  # 1 while the episode continues, 0 at its end
                    )
                    env.render()
                    if done:
                        break
        finally:
            # Release the environment even if a step above raised.
            env.close()
    finally:
        # Never leave printing disabled for the caller.
        enablePrint()
    print('Done')
    return agent.memory.memory
def collectData(info):
    """Roll out a fresh ``Agent`` in fruitbot and save what it collected.

    NOTE(review): a function with the same name taking an ``agent`` argument
    appears earlier in this file; if both live at module level, this later
    definition replaces it - confirm that is intended.

    Runs ``i`` environment steps in total, starting a new episode
    (``env.reset()`` plus zeroed state tensors) whenever the previous one
    reports ``done``. Every transition is handed to ``agent.remember``; after
    the rollout the agent is passed to ``saveData(agent, location, ID)``.

    Printing is disabled via ``disablePrint()`` while collecting and saving
    and re-enabled afterwards; the environment is always closed, even when a
    step raises.

    Args:
        info: 3-item sequence ``(i, location, ID)``. ``i`` is both the step
            budget and the ``memory`` argument given to ``Agent``;
            ``location`` and ``ID`` are forwarded to ``saveData`` and ``ID``
            is echoed in the start/done messages.

    Returns:
        The PID of the process that ran the rollout (``os.getpid()``).
    """
    i, location, ID = info
    print('Start', ID)
    disablePrint()
    try:
        agent = Agent(memory=i)
        env = Environment(render=False).fruitbot
        try:
            steps_left = i
            while steps_left > 0:
                # New episode: fresh observation and zero-initialised state.
                # presumably hn/cn are LSTM hidden/cell states - TODO confirm
                obs = clean(env.reset())
                hn = torch.zeros(2, 1, hidden_size, device=device)
                cn = torch.zeros(2, 1, hidden_size, device=device)
                while steps_left > 0:
                    steps_left -= 1
                    # hn/cn are fed back into choose() un-detached; only the
                    # copies stored in memory are detached.
                    act, obs_old, h0, c0, hn, cn = agent.choose(obs, hn, cn)
                    obs, rew, done, _ = env.step(act)
                    # remember() returns the observation used for the next step.
                    obs = agent.remember(
                        obs_old.detach(),
                        act,
                        clean(obs).detach(),
                        rew,
                        h0.detach(),
                        c0.detach(),
                        hn.detach(),
                        cn.detach(),
                        int(not done),  # 1 while the episode continues, 0 at its end
                    )
                    env.render()
                    if done:
                        break
        finally:
            # Release the environment even if a step above raised.
            env.close()
        saveData(agent, location, ID)
    finally:
        # Never leave printing disabled, including when saveData fails.
        enablePrint()
    print('Done', ID)
    return os.getpid()