Example #1
# Assumed imports for this snippet; CVAE, load_model and make_result_img are
# project-local helpers that are referenced but not defined here.
import torch
from torchvision import datasets, transforms, utils


def main(args):
    transform = transforms.Compose([
        transforms.Resize([64, 64], 1),  # interpolation code 1 = PIL LANCZOS
        transforms.ToTensor(),
        transforms.Normalize([.5], [.5])  # map pixels into [-1, 1]
    ])
    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)

    data = torch.stack([img for img, _ in dataset])  # keep images, drop labels

    cvae = CVAE(img_size=data.shape[1:], z_dim=args.z_dim)
    pt_file = load_model(args.model_load_path, "*10.pt")
    cvae.load_state_dict(torch.load(pt_file))
    cvae.eval()

    if args.decode:
        #z_data = [torch.randn(data.shape[0], args.z_dim), data]
        #_, rec_img = cvae.decoder(*z_data)
        #grid_img = make_result_img([data, rec_img], normalize=True, range=(-1., 1.))
        #utils.save_image(grid_img, "{}CVAE_gen_result.png".format(args.log_path))

        # tile the first image as the condition for a batch of 32 samples
        cond_img = data[None, 0].repeat([32, 1, 1, 1])
        z_data = [torch.randn(cond_img.shape[0], args.z_dim), cond_img]
        _, rec_img = cvae.decoder(*z_data)
        #grid_img = make_result_img([rec_img], normalize=True, range=(-1., 1.))
        #utils.save_image(grid_img, "{}CVAE_gen_result_same_cond.png".format(args.log_path))
        for i in range(rec_img.shape[0]):
            utils.save_image(rec_img[i],
                             "{}goal{:0>6d}.png".format(args.log_path, i),
                             normalize=True,
                             range=(-1., 1.))

    if args.gen_seq:
        for i, d in enumerate(data):
            cond_img = data[None, i]
            z_data = [torch.randn(1, args.z_dim), cond_img]
            _, rec_img = cvae.decoder(*z_data)
            grid_img = make_result_img([cond_img, rec_img],
                                       normalize=True,
                                       range=(-1., 1.))
            utils.save_image(
                grid_img, "{}res_state-goal/CVAE_gen_{:0>6d}.png".format(
                    args.log_path, i))
            utils.save_image(rec_img,
                             "{}res_goal/CVAE_gen_{:0>6d}.png".format(
                                 args.log_path, i),
                             normalize=True,
                             range=(-1., 1.))
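
For context, a minimal driver for this `main` might look like the following sketch. The flag names mirror the `args.*` attributes used above; the defaults (e.g. `z_dim=8`) are illustrative assumptions, not the original project's settings.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--model_load_path", required=True)
    parser.add_argument("--log_path", default="./result/")
    parser.add_argument("--z_dim", type=int, default=8)  # assumed latent size
    parser.add_argument("--decode", action="store_true")
    parser.add_argument("--gen_seq", action="store_true")
    main(parser.parse_args())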
Example #2
# Assumed imports for this snippet; CVAE and make_result_img are
# project-local helpers that are referenced but not defined here.
import glob
import os

import torch
from torchvision import datasets, transforms, utils


def main(args):
    transform = transforms.Compose([
        transforms.Resize([64, 64], 1),  # interpolation code 1 = PIL LANCZOS
        transforms.ToTensor(),
        transforms.Normalize([.5], [.5])  # map pixels into [-1, 1]
    ])
    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)

    data = torch.stack([img for img, _ in dataset])  # keep images, drop labels

    cvae = CVAE(img_size=data.shape[1:], z_dim=args.z_dim)
    cvae.eval()
    pt_files = glob.glob(os.path.join(args.model_load_path, "*.pt"))
    pt_files.sort()
    #data_1 = data[0][None, :, :, :].repeat([20, 1, 1, 1])
    #data_2 = data[1][None, :, :, :].repeat([20, 1, 1, 1])

    for i, pt_file in enumerate(pt_files):
        print(pt_file)
        cvae.load_state_dict(torch.load(pt_file))

        #z_data = [torch.randn(data.shape[0], args.z_dim), data]
        z_data = [
            torch.randn(32, args.z_dim), data[None, 0].repeat([32, 1, 1, 1])
        ]
        _, rec_img = cvae.decoder(*z_data)
        grid_img = make_result_img(
            [data[None, 0].repeat([32, 1, 1, 1]), rec_img],
            normalize=True,
            range=(-1., 1.))
        utils.save_image(
            grid_img, "{}CVAE_gen_result_{:0>2d}.png".format(args.log_path, i))
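
`load_model` is not defined in these snippets, but the calls in Examples #1 and #3 and the inline glob logic above suggest it resolves a glob pattern under a directory to a single checkpoint path. A minimal sketch under that assumption (the tie-breaking rule is a guess):

import glob
import os

def load_model(model_dir, pattern):
    # Hypothetical helper: expand the glob under model_dir and return one
    # checkpoint path; here, the last match in sorted order.
    matches = sorted(glob.glob(os.path.join(model_dir, pattern)))
    assert matches, "no checkpoint matching {} in {}".format(pattern, model_dir)
    return matches[-1]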
Example #3
# Assumed imports for this snippet; LSTMPB, VAE, CVAE, load_model and
# HistoryWindow are project-local modules referenced but not defined here.
import os

import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans
from torchvision import transforms, utils


class NNModel(object):
    def __init__(self, args):
        self.log_path = args.log_path
        self.device = torch.device("cuda:0" if args.cuda else "cpu")
        self.img_size = args.img_size
        self.sample_num = args.sample_num
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([64, 64], 1),
            transforms.ToTensor(),
            transforms.Normalize([.5], [.5])
        ])
        self.pil_transform = transforms.ToPILImage(mode="RGB")

        self.norm_scale = np.loadtxt(os.path.join(args.config_path,
                                                  "norm_scale.txt"),
                                     dtype=np.float32,
                                     delimiter=",")[None]
        self.norm_min = np.loadtxt(os.path.join(args.config_path,
                                                "norm_min.txt"),
                                   dtype=np.float32,
                                   delimiter=",")[None]
        self.pb_list = torch.from_numpy(
            np.loadtxt(os.path.join(args.config_path, "pb_list.txt"),
                       dtype=np.float32,
                       delimiter=","))

        self.kmeans = KMeans(n_clusters=2)
        self.kmeans.fit(self.pb_list)

        print("=" * 5, "Init LSTMPB", "=" * 5)
        self.rnn = LSTMPB(args, pb_unit=self.pb_list[5][None])
        pt_file = load_model(args.model_load_path, "*/*LSTMPB*.pt")
        self.rnn.load_state_dict(torch.load(pt_file))

        print("=" * 5, "Init VAE", "=" * 5)
        self.vae = VAE(img_size=args.img_size, z_dim=args.vae_z_dims)
        pt_file = load_model(args.model_load_path, "*/VAE*.pt")
        self.vae.load_state_dict(torch.load(pt_file))
        self.vae.eval()

        print("=" * 5, "Init CVAE", "=" * 5)
        self.cvae = CVAE(img_size=args.img_size, z_dim=args.cvae_z_dims)
        pt_file = load_model(args.model_load_path, "*/*CVAE*.pt")
        self.cvae.load_state_dict(torch.load(pt_file))
        self.cvae.eval()

        self.norm_mode = {
            "joint": [0, 1, 2, 3, 4],
            "visual": [5, 6, 7, 8, 9, 10, 11]
        }
        self.norm_mode["all"] = (self.norm_mode["joint"] +
                                 self.norm_mode["visual"])

        self.global_step = 0
        self.his_log = HistoryWindow(maxlen=args.window_size)

        # visualize the current goal decoded from the PB unit
        _, goal = self.vae.decoder(self.denorm(self.goal, "visual"))
        # undo the [-1, 1] normalization; uint8 (not int8) avoids overflow
        goal = ((goal[0] * .5 + .5) * 255).to(torch.uint8)
        self.goal_img = self.pil_transform(goal)

    def on_predict(self, cur_joint, cur_img, state=None):
        cur_joint = torch.Tensor(cur_joint)[None]
        # BGR -> RGB; .copy() makes the reversed view contiguous for ToPILImage
        cur_img = self.transform(cur_img[:, :, ::-1].copy())[None]
        utils.save_image(cur_img[0],
                         "./result/visual_{:0>6d}.png".format(
                             self.global_step),
                         normalize=True,
                         range=(-1, 1))

        img_feature = self.vae.reparam(*self.vae.encoder(cur_img))
        inputs = torch.cat([cur_joint, img_feature], dim=-1).detach()
        inputs = self.norm(inputs).to(torch.float32)

        outputs, state = self.rnn.step(inputs, state)
        outputs, state = outputs.detach().cpu(), \
                         (state[0].detach().cpu(), state[1].detach().cpu())
        self.global_step += 1
        return outputs, state, self.denorm(outputs).to(torch.float32)

    def off_predict(self, cur_joint, img_feature, state=None):
        assert isinstance(cur_joint, (list, np.ndarray))
        assert isinstance(img_feature, (list, np.ndarray))

        cur_joint = torch.Tensor(cur_joint).to(torch.float32)[None]
        img_feature = torch.Tensor(img_feature).to(torch.float32)[None]

        inputs = torch.cat([cur_joint, img_feature], dim=-1)
        outputs, state = self.rnn.step(inputs, state)
        outputs, state = outputs.detach().cpu(), \
                         (state[0].detach().cpu(), state[1].detach().cpu())

        self.his_log.put([outputs, inputs, state])
        return outputs, state, self.denorm(outputs).to(torch.float32)

    def gen_goal(self, visual_img):
        visual_img = self.transform(visual_img)[None].repeat(
            self.sample_num, 1, 1, 1)
        sampled_z = torch.randn(self.sample_num, self.cvae.z_dim)
        _, gen_goals = self.cvae.decoder(z=sampled_z, cond=visual_img)
        pb_list = self.vae.reparam(*self.vae.encoder(gen_goals)).detach().cpu()
        #for i in range(gen_goals.shape[0]):
        #    utils.save_image(gen_goals[i], "{}gen_goal{:0>6d}.png".format("./", i),normalize=True, range=(-1., 1.))

        pb_label = self.kmeans.predict(pb_list.numpy())
        print(pb_label)
        pb_list = torch.stack(
            [pb_list[pb_label == 0].mean(0), pb_list[pb_label == 1].mean(0)])
        _, goal_list = self.vae.decoder(pb_list)
        pb_list = self.norm(pb_list, "visual")
        # undo the [-1, 1] normalization; uint8 (not int8) avoids overflow
        goal_list = ((goal_list * .5 + .5) * 255).to(torch.uint8)
        goal_list = [self.pil_transform(goal) for goal in goal_list]
        return goal_list, pb_list

    def pem(self):
        assert len(self.his_log), "the history window is empty!"
        # freeze the network weights; only the parametric-bias (PB) unit
        # is optimized against the logged history
        for param in self.rnn.parameters():
            param.requires_grad = False
        self.rnn.pb_unit = nn.Parameter(self.rnn.pb_unit, requires_grad=True)
        optim_param = [
            param for param in self.rnn.parameters() if param.requires_grad
        ]
        optim = torch.optim.Adam(optim_param, lr=0.01)
        mse_loss = nn.MSELoss()

        pred_his, actual_his, state_his = self.his_log.get()
        pb_log = []
        for i in range(80):
            log = []
            cur_input = torch.cat([pred_his[:, 0, :5], actual_his[:, 0, 5:]],
                                  dim=-1)
            state = state_his[0]
            for step in range(1, len(state_his)):
                cur_input, state = self.rnn.step(cur_input, state)
                log.append(cur_input)
            log = torch.stack(log, dim=1)
            loss = mse_loss(log[0, :, 5:], actual_his[0, 1:, 5:]) + \
                   (self.rnn.pb_unit - self.pb_list).pow(2).mean()
            pb_log.append(self.rnn.pb_unit.data.clone())
            optim.zero_grad()  # gradients would otherwise accumulate
            loss.backward()
            optim.step()
            print("PEM step {}, loss: {}".format(i, loss.item()))

    @property
    def goal(self):
        return self.rnn.pb_unit

    @goal.setter
    def goal(self, pb):
        if pb.ndim == 1:
            pb = torch.unsqueeze(pb, 0)
        self.rnn.pb_unit = pb

    def norm(self, inputs, mode="all"):
        assert mode in ["joint", "visual", "all"]
        i_slice = self.norm_mode[mode]
        return inputs * self.norm_scale[:, i_slice] + self.norm_min[:, i_slice]

    def denorm(self, outputs, mode="all"):
        assert mode in ["joint", "visual", "all"]
        i_slice = self.norm_mode[mode]
        return (outputs - self.norm_min[:, i_slice]) / self.norm_scale[:, i_slice]
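
`norm` and `denorm` are mutually inverse affine maps applied per feature column: `norm(x) = x * scale + min` and `denorm(y) = (y - min) / scale`, so `denorm(norm(x)) == x` up to floating-point error. A quick self-contained round-trip check, with made-up values standing in for `norm_scale` and `norm_min`:

import numpy as np
import torch

scale = np.array([[2.0, 4.0]], dtype=np.float32)    # stand-in for norm_scale
offset = np.array([[-1.0, 0.5]], dtype=np.float32)  # stand-in for norm_min

x = torch.tensor([[0.25, 0.75]])
y = x * scale + offset          # what norm() computes
x_rec = (y - offset) / scale    # what denorm() computes
assert torch.allclose(x, x_rec)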