def train(batch_size, max_episode_length=10, n_steps=50000, lr=1e-2,
          log_every=100):
    """Train an actor network by supervised regression with an MSE loss.

    Every iteration draws a fresh ``(state, y_target)`` pair from
    ``Paint.reset_with_gen()``, builds the network input from it and takes
    one Adam step on the MSE between the actor output and the flattened
    target.

    Args:
        batch_size: Number of samples per environment batch; also used to
            flatten ``y_target`` to ``(batch_size, -1)``.
        max_episode_length: Passed to ``Paint`` and used to normalise the
            step-number channel of the state.
        n_steps: Number of optimisation steps to run.
        lr: Learning rate handed to ``optim.Adam``.
        log_every: Print the loss every ``log_every`` steps; a falsy value
            (``0`` / ``None``) disables printing.

    Relies on the module-level names ``Paint``, ``ResNet``, ``coord``,
    ``action_dim`` and ``n_frames_per_step`` defined elsewhere in the file.
    """
    env = Paint(batch_size, max_episode_length)
    # Input channels: target, canvas, stepnum, coordconv -> 3 + 3 + 1 + 2 = 9.
    actor = ResNet(9, 18, (action_dim + 3) * n_frames_per_step)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(actor.parameters(), lr=lr)

    for step in range(n_steps):
        state, y_target = env.reset_with_gen()
        y_target = y_target.view(batch_size, -1)
        state = torch.cat(
            (
                # First six channels are scaled by 1/255.
                state[:, :6].float() / 255,
                # Channel 6 is scaled by the episode length.
                state[:, 6:7].float() / max_episode_length,
                # NOTE(review): the 128x128 spatial size is hard-coded and
                # must match both ``coord`` and ``state`` -- confirm.
                coord.expand(state.shape[0], 2, 128, 128),
            ),
            1,
        )
        # The optimizer owns exactly actor.parameters(), so this clears the
        # same gradients as actor.zero_grad().
        optimizer.zero_grad()
        y = actor(state)
        loss = loss_fn(y, y_target)
        loss.backward()
        optimizer.step()
        if log_every and step % log_every == 0:
            # .item() extracts a plain Python float for formatting.
            print("step %d: loss %f" % (step, loss.item()))
# Example #2
# 0
class fastenv:
    """Batched wrapper around ``Paint`` that adds optional image/scalar logging.

    The wrapper forwards ``step``/``reset`` to the underlying ``Paint``
    environment and, when a ``writer`` is supplied, logs canvases, targets
    and the per-sample distance to it. All logging is skipped when
    ``writer`` is ``None`` (the constructor default).
    """

    def __init__(self, max_episode_length=10, env_batch=64, writer=None):
        """Create the underlying ``Paint`` environment.

        Args:
            max_episode_length: Episode length passed to ``Paint``.
            env_batch: Number of environments stepped together.
            writer: Object exposing ``add_image(tag, img, step)`` and
                ``add_scalar(tag, value, step)``, or ``None`` to disable
                logging.
        """
        self.max_episode_length = max_episode_length
        self.env_batch = env_batch
        self.env = Paint(self.env_batch, self.max_episode_length)
        # self.env.load_data()
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.writer = writer
        self.test = False
        self.log = 0  # running index used as the step for "train/dist"
        self.dist = None  # last value of get_dist(), set at episode end

    def _to_rgb(self, image):
        """Convert one image tensor to a channel-last array with B/R swapped."""
        return cv2.cvtColor(to_numpy(image.permute(1, 2, 0)),
                            cv2.COLOR_BGR2RGB)

    def save_image_with_gen(self, log, step):
        """Log the canvas and target of the first sample in the batch.

        Does nothing when no writer is configured.
        """
        if self.writer is None:
            return
        for i in range(1):
            canvas = to_numpy(self.env.canvas[i].permute(1, 2, 0))
            self.writer.add_image(
                "images/%d-%d-%d_canvas.png" % (log, i, step), canvas, log)
            # Kept inside the loop so it no longer relies on the loop
            # variable leaking out of the ``for`` statement.
            gt = to_numpy(self.env.gt[i].permute(1, 2, 0))
            self.writer.add_image(
                "images/%d-%d_target.png" % (log, i), gt, log)

    def save_image(self, log, step):
        """Log canvases (and, at the final step, targets) for low image ids.

        Canvases of samples with ``imgid <= 10`` are logged on every call.
        When ``step == max_episode_length`` the target and canvas of samples
        with ``imgid < 50`` are logged as well. Does nothing when no writer
        is configured.
        """
        if self.writer is None:
            return
        for i in range(self.env_batch):
            if self.env.imgid[i] <= 10:
                self.writer.add_image(
                    "{}/canvas_{}.png".format(str(self.env.imgid[i]),
                                              str(step)),
                    self._to_rgb(self.env.canvas[i]),
                    log,
                )
        if step != self.max_episode_length:
            return
        for i in range(self.env_batch):
            if self.env.imgid[i] < 50:
                gt = self._to_rgb(self.env.gt[i])
                canvas = self._to_rgb(self.env.canvas[i])
                self.writer.add_image(
                    str(self.env.imgid[i]) + "/_target.png", gt, log)
                self.writer.add_image(
                    str(self.env.imgid[i]) + "/_canvas.png", canvas, log)

    def step(self, action):
        """Step the wrapped environment and log distances at episode end.

        Args:
            action: Batch of actions; converted with ``torch.tensor`` and
                moved to the module-level ``device``.

        Returns:
            The 4-tuple returned by ``Paint.step`` unchanged.
        """
        with torch.no_grad():
            ob, r, d, info = self.env.step(torch.tensor(action).to(device))
        # Only the first done flag is inspected to detect episode end.
        if d[0] and not self.test:
            self.dist = self.get_dist()
            for i in range(self.env_batch):
                if self.writer is not None:
                    self.writer.add_scalar("train/dist", self.dist[i],
                                           self.log)
                # The counter advances even without a writer so that the
                # logging index does not depend on whether logging is on.
                self.log += 1
        return ob, r, d, info

    def get_dist(self):
        """Return the per-sample mean squared difference of target and canvas.

        Both tensors are cast to float, their difference is divided by 255
        and squared, then averaged over dimensions 1, 2 and 3 in turn,
        leaving one value per entry of dimension 0.
        """
        diff = (self.env.gt.float() - self.env.canvas.float()) / 255
        return to_numpy((diff**2).mean(1).mean(1).mean(1))

    def reset(self, test=False, episode=0):
        """Reset the wrapped environment and return its observation.

        Args:
            test: Stored on ``self.test``; disables distance logging in
                ``step`` when true.
            episode: Episode index; ``episode * env_batch`` is forwarded to
                ``Paint.reset_with_gen`` as its second argument.
        """
        self.test = test
        # ob = self.env.reset(self.test, episode * self.env_batch)
        return self.env.reset_with_gen(self.test, episode * self.env_batch)