def train(batch_size, max_episode_length=10, num_steps=50000, lr=1e-2,
          log_interval=100):
    """Fit a ResNet actor by MSE regression against generated targets.

    Every iteration draws a fresh ``(state, y_target)`` pair from
    ``Paint.reset_with_gen()``, normalises the state, appends the module-level
    ``coord`` channels and takes one Adam step on the MSE between the actor
    output and the flattened target.

    Args:
        batch_size: Number of samples per batch; passed to ``Paint`` and used
            to flatten ``y_target`` to ``(batch_size, -1)``.
        max_episode_length: Passed to ``Paint`` and used to scale the step
            channel (``state[:, 6:7]``).
        num_steps: Number of optimisation steps (default matches the previous
            hard-coded 50000).
        lr: Adam learning rate (default matches the previous hard-coded 1e-2).
        log_interval: Print the loss every ``log_interval`` steps; ``0`` or
            ``None`` disables printing.

    Returns:
        The trained ``ResNet`` actor, so the result of the run is not lost.
    """
    env = Paint(batch_size, max_episode_length)
    # target, canvas, stepnum, coordconv 3 + 3 + 1 + 2  (= 9 input channels)
    # NOTE(review): the second argument (18) is presumably the ResNet depth —
    # confirm against the ResNet definition.
    actor = ResNet(9, 18, (action_dim + 3) * n_frames_per_step)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(actor.parameters(), lr=lr)

    for step in range(num_steps):
        state, y_target = env.reset_with_gen()
        y_target = y_target.view(batch_size, -1)

        n_samples = state.shape[0]
        state = torch.cat(
            (
                # First six channels are scaled by 1/255.
                state[:, :6].float() / 255,
                # Seventh channel is scaled by the episode length.
                state[:, 6:7].float() / max_episode_length,
                # Broadcast the shared coordinate channels to every sample.
                coord.expand(n_samples, 2, 128, 128),
            ),
            1,
        )

        actor.zero_grad()
        y = actor(state)
        loss = loss_fn(y, y_target)
        loss.backward()
        optimizer.step()

        if log_interval and step % log_interval == 0:
            print("step %d: loss %f" % (step, loss.item()))

    return actor
class fastenv:
    """Batched wrapper around a ``Paint`` environment with optional logging.

    The wrapper forwards ``step``/``reset`` to the underlying ``Paint``
    instance and, when a writer is attached, records canvases, targets and
    per-sample distances through the writer's ``add_image``/``add_scalar``
    methods. When ``writer`` is ``None`` (the constructor default) all logging
    is skipped instead of raising ``AttributeError``.
    """

    def __init__(self, max_episode_length=10, env_batch=64, writer=None):
        """Create the underlying ``Paint`` environment.

        Args:
            max_episode_length: Forwarded to ``Paint``; also the step at which
                ``save_image`` writes the final target/canvas pair.
            env_batch: Number of parallel environments, forwarded to ``Paint``.
            writer: Object exposing ``add_image`` and ``add_scalar``
                (presumably a TensorBoard-style summary writer — confirm with
                the caller), or ``None`` to disable logging.
        """
        self.max_episode_length = max_episode_length
        self.env_batch = env_batch
        self.env = Paint(self.env_batch, self.max_episode_length)
        # self.env.load_data()
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.writer = writer
        self.test = False
        # Global step used for the "train/dist" scalar.
        self.log = 0

    def _as_rgb_image(self, frame):
        """Convert one sample to NumPy, first axis moved last, BGR -> RGB."""
        return cv2.cvtColor(to_numpy(frame.permute(1, 2, 0)), cv2.COLOR_BGR2RGB)

    def save_image_with_gen(self, log, step):
        """Log the canvas and target of the first sample in the batch.

        Args:
            log: Global step passed to the writer; also part of the tag.
            step: Episode step, used only in the canvas tag.
        """
        if self.writer is None:
            return
        # Only sample 0 is logged.
        index = 0
        canvas = to_numpy(self.env.canvas[index].permute(1, 2, 0))
        self.writer.add_image(
            "images/%d-%d-%d_canvas.png" % (log, index, step), canvas, log)
        gt = to_numpy(self.env.gt[index].permute(1, 2, 0))
        self.writer.add_image(
            "images/%d-%d_target.png" % (log, index), gt, log)

    def save_image(self, log, step):
        """Log canvases (and, at the final step, targets) for low image ids.

        At every call the canvas of each sample whose ``imgid`` is <= 10 is
        written. When ``step`` equals ``max_episode_length`` the target and
        canvas of each sample whose ``imgid`` is < 50 are written as well.

        Args:
            log: Global step passed to the writer.
            step: Current episode step.
        """
        if self.writer is None:
            return

        for i in range(self.env_batch):
            img_id = self.env.imgid[i]
            if img_id <= 10:
                canvas = self._as_rgb_image(self.env.canvas[i])
                self.writer.add_image(
                    "{}/canvas_{}.png".format(str(img_id), str(step)),
                    canvas,
                    log,
                )

        if step != self.max_episode_length:
            return

        for i in range(self.env_batch):
            img_id = self.env.imgid[i]
            if img_id < 50:
                gt = self._as_rgb_image(self.env.gt[i])
                canvas = self._as_rgb_image(self.env.canvas[i])
                self.writer.add_image(str(img_id) + "/_target.png", gt, log)
                self.writer.add_image(str(img_id) + "/_canvas.png", canvas, log)

    def step(self, action):
        """Advance every environment by one action.

        Args:
            action: Batch of actions; converted with ``torch.tensor`` and moved
                to the module-level ``device`` before being passed on.

        Returns:
            The ``(observation, reward, done, info)`` tuple returned by
            ``Paint.step``, unchanged.
        """
        with torch.no_grad():
            ob, r, d, _ = self.env.step(torch.tensor(action).to(device))

        # The done flag of the first environment is used as the signal that
        # the batch has finished its episode.
        if d[0] and not self.test:
            self.dist = self.get_dist()
            for i in range(self.env_batch):
                if self.writer is not None:
                    self.writer.add_scalar("train/dist", self.dist[i], self.log)
                # NOTE(review): the counter advances once per sample so every
                # logged distance gets its own global step — confirm this
                # matches the intended logging cadence.
                self.log += 1
        return ob, r, d, _

    def get_dist(self):
        """Return the per-sample mean squared target/canvas difference.

        The difference is scaled by 1/255 before squaring and then averaged
        over every non-batch axis (three successive ``mean(1)`` calls), giving
        one value per sample as a NumPy array.
        """
        diff = (self.env.gt.float() - self.env.canvas.float()) / 255
        return to_numpy((diff ** 2).mean(1).mean(1).mean(1))

    def reset(self, test=False, episode=0):
        """Reset the underlying environment and remember the test flag.

        Args:
            test: Stored on ``self.test``; training-distance logging in
                ``step`` is skipped while it is true.
            episode: Episode index; multiplied by ``env_batch`` to form the
                second argument of ``Paint.reset_with_gen``.

        Returns:
            Whatever ``Paint.reset_with_gen`` returns, unchanged.
        """
        self.test = test
        # ob = self.env.reset(self.test, episode * self.env_batch)
        # NOTE(review): train() in this file calls reset_with_gen() with no
        # arguments and unpacks a (state, target) pair — confirm both call
        # styles are supported by Paint.
        ob = self.env.reset_with_gen(self.test, episode * self.env_batch)
        return ob