def main(args):
    """Generate goal images with a trained CVAE conditioned on dataset images.

    Loads the epoch-10 checkpoint, then optionally:
      * ``--decode``: decodes 32 random latents all conditioned on the first
        dataset image and saves each sample individually.
      * ``--gen_seq``: decodes one random latent per dataset image and saves
        the result both alone and side-by-side with its conditioning image.

    All outputs are written under ``args.log_path``.
    """
    transform = transforms.Compose([
        transforms.Resize([64, 64], 1),
        transforms.ToTensor(),
        transforms.Normalize([.5], [.5]),
    ])
    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)
    # Class labels are irrelevant here; stack all images into (N, C, H, W).
    data = torch.stack([img for img, _ in dataset])

    cvae = CVAE(img_size=data.shape[1:], z_dim=args.z_dim)
    pt_file = load_model(args.model_load_path, "*10.pt")
    cvae.load_state_dict(torch.load(pt_file))
    cvae.eval()

    if args.decode:
        # 32 random latents, every one conditioned on the first dataset image.
        cond_img = data[None, 0].repeat([32, 1, 1, 1])
        z_data = [torch.randn(cond_img.shape[0], args.z_dim), cond_img]
        _, rec_img = cvae.decoder(*z_data)
        for i in range(rec_img.shape[0]):
            utils.save_image(rec_img[i],
                             "{}goal{:0>6d}.png".format(args.log_path, i),
                             normalize=True,
                             range=(-1., 1.))

    if args.gen_seq:
        # One sampled goal per dataset image.
        for i, _ in enumerate(data):
            cond_img = data[None, i]
            z_data = [torch.randn(1, args.z_dim), cond_img]
            _, rec_img = cvae.decoder(*z_data)
            grid_img = make_result_img([cond_img, rec_img],
                                       normalize=True,
                                       range=(-1., 1.))
            utils.save_image(
                grid_img,
                "{}res_state-goal/CVAE_gen_{:0>6d}.png".format(
                    args.log_path, i))
            utils.save_image(rec_img,
                             "{}res_goal/CVAE_gen_{:0>6d}.png".format(
                                 args.log_path, i),
                             normalize=True,
                             range=(-1., 1.))
def main(args):
    """Render CVAE goal samples for every saved checkpoint.

    For each ``*.pt`` file under ``args.model_load_path`` (sorted), loads the
    weights, decodes 32 random latents conditioned on the first dataset image,
    and saves a comparison grid per checkpoint under ``args.log_path``.

    NOTE(review): this redefines ``main`` — the earlier ``main`` in this file
    is shadowed at import time. If both are meant to be callable, one of them
    should be renamed.
    """
    transform = transforms.Compose([
        transforms.Resize([64, 64], 1),
        transforms.ToTensor(),
        transforms.Normalize([.5], [.5]),
    ])
    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)
    # Class labels are irrelevant here; stack all images into (N, C, H, W).
    data = torch.stack([img for img, _ in dataset])

    cvae = CVAE(img_size=data.shape[1:], z_dim=args.z_dim)
    cvae.eval()
    pt_files = glob.glob(os.path.join(args.model_load_path, "*.pt"))
    pt_files.sort()

    # The conditioning batch is identical for every checkpoint — build it once.
    cond_img = data[None, 0].repeat([32, 1, 1, 1])
    for i, pt_file in enumerate(pt_files):
        print(pt_file)
        cvae.load_state_dict(torch.load(pt_file))
        z_data = [torch.randn(32, args.z_dim), cond_img]
        _, rec_img = cvae.decoder(*z_data)
        grid_img = make_result_img([cond_img, rec_img],
                                   normalize=True,
                                   range=(-1., 1.))
        utils.save_image(
            grid_img,
            "{}CVAE_gen_result_{:0>2d}.png".format(args.log_path, i))
class NNModel(object):
    """Wraps the trained LSTMPB / VAE / CVAE models for online robot control.

    Responsibilities: camera-frame preprocessing, affine (de)normalization of
    the concatenated [joint | visual] feature vector, goal-image generation
    via the CVAE, and plan-error minimization (PEM) over the LSTMPB pb unit.
    """

    def __init__(self, args):
        self.log_path = args.log_path
        self.device = torch.device("cuda:0" if args.cuda else "cpu")
        self.img_size = args.img_size
        self.sample_num = args.sample_num
        # Raw frames -> 64x64 tensors scaled to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([64, 64], 1),
            transforms.ToTensor(),
            transforms.Normalize([.5], [.5]),
        ])
        self.pil_transform = transforms.ToPILImage(mode="RGB")
        # Per-dimension affine normalization constants, shape (1, D).
        self.norm_scale = np.loadtxt(os.path.join(args.config_path,
                                                  "norm_scale.txt"),
                                     dtype=np.float32,
                                     delimiter=",")[None]
        self.norm_min = np.loadtxt(os.path.join(args.config_path,
                                                "norm_min.txt"),
                                   dtype=np.float32,
                                   delimiter=",")[None]
        self.pb_list = torch.from_numpy(
            np.loadtxt(os.path.join(args.config_path, "pb_list.txt"),
                       dtype=np.float32,
                       delimiter=","))
        # Cluster the known pb vectors into two goal modes (used by gen_goal).
        self.kmeans = KMeans(n_clusters=2)
        self.kmeans.fit(self.pb_list)
        print("=" * 5, "Init LSTMPB", "=" * 5)
        self.rnn = LSTMPB(args, pb_unit=self.pb_list[5][None])
        pt_file = load_model(args.model_load_path, "*/*LSTMPB*.pt")
        self.rnn.load_state_dict(torch.load(pt_file))
        print("=" * 5, "Init VAE", "=" * 5)
        self.vae = VAE(img_size=args.img_size, z_dim=args.vae_z_dims)
        pt_file = load_model(args.model_load_path, "*/VAE*.pt")
        self.vae.load_state_dict(torch.load(pt_file))
        self.vae.eval()
        print("=" * 5, "Init CVAE", "=" * 5)
        self.cvae = CVAE(img_size=args.img_size, z_dim=args.cvae_z_dims)
        pt_file = load_model(args.model_load_path, "*/*CVAE*.pt")
        self.cvae.load_state_dict(torch.load(pt_file))
        self.cvae.eval()
        # Index groups into the concatenated [joint | visual] feature vector:
        # dims 0-4 are joints, dims 5-11 are VAE visual features.
        self.norm_mode = {
            "joint": [0, 1, 2, 3, 4],
            "visual": [5, 6, 7, 8, 9, 10, 11]
        }
        self.norm_mode[
            "all"] = self.norm_mode["joint"] + self.norm_mode["visual"]
        self.global_step = 0
        self.his_log = HistoryWindow(maxlen=args.window_size)
        # Visualize the current goal for inspection.
        _, goal = self.vae.decoder(self.denorm(self.goal, "visual"))
        # BUGFIX: uint8, not int8 — pixel values above 127 wrap negative in
        # int8, corrupting the rendered goal image.
        goal = ((goal[0] * .5 + .5) * 255).to(torch.uint8)
        self.goal_img = self.pil_transform(goal)

    def on_predict(self, cur_joint, cur_img, state=None):
        """Run one online RNN step from a raw joint vector and camera frame.

        Args:
            cur_joint: current joint values (sequence of 5 floats).
            cur_img: camera frame as an H x W x C array, BGR channel order
                (OpenCV convention — the channel axis is reversed below).
            state: previous LSTM (h, c) state, or None to start fresh.

        Returns:
            (normalized_outputs, new_state, denormalized_outputs).
        """
        cur_joint = torch.Tensor(cur_joint)[None]
        # BGR -> RGB, then preprocess and add a batch axis.
        cur_img = self.transform(cur_img[:, :, ::-1])[None]
        utils.save_image(cur_img[0],
                         "./result/visual_{:0>6d}.png".format(
                             self.global_step),
                         normalize=True,
                         range=(-1, 1))
        img_feature = self.vae.reparam(*self.vae.encoder(cur_img))
        inputs = torch.cat([cur_joint, img_feature], axis=-1).detach()
        inputs = self.norm(inputs).to(torch.float32)
        outputs, state = self.rnn.step(inputs, state)
        outputs, state = outputs.detach().cpu(), \
            (state[0].detach().cpu(), state[1].detach().cpu())
        self.global_step += 1
        return outputs, state, self.denorm(outputs).to(torch.float32)

    def off_predict(self, cur_joint, img_feature, state=None):
        """Run one RNN step from precomputed features, logging to history.

        Unlike on_predict, the visual feature is supplied directly and the
        (output, input, state) triple is recorded in the history window so
        pem() can later optimize the pb unit against it.
        """
        assert isinstance(cur_joint, (list, np.ndarray))
        assert isinstance(img_feature, (list, np.ndarray))
        cur_joint = torch.Tensor(cur_joint).to(torch.float32)[None]
        img_feature = torch.Tensor(img_feature).to(torch.float32)[None]
        inputs = torch.cat([cur_joint, img_feature], axis=-1)
        outputs, state = self.rnn.step(inputs, state)
        outputs, state = outputs.detach().cpu(), \
            (state[0].detach().cpu(), state[1].detach().cpu())
        self.his_log.put([outputs, inputs, state])
        return outputs, state, self.denorm(outputs).to(torch.float32)

    def gen_goal(self, visual_img):
        """Sample goal images with the CVAE conditioned on the current view.

        Draws ``sample_num`` latents, encodes the generated goals back into
        VAE feature space, collapses them to one mean pb vector per k-means
        cluster, and decodes those means into goal images.

        Returns:
            (list of PIL goal images — one per cluster,
             normalized pb vectors, shape (2, z_dim)).
        """
        visual_img = self.transform(visual_img)[None].repeat(
            self.sample_num, 1, 1, 1)
        sampled_z = torch.randn(self.sample_num, self.cvae.z_dim)
        _, gen_goals = self.cvae.decoder(z=sampled_z, cond=visual_img)
        pb_list = self.vae.reparam(*self.vae.encoder(gen_goals)).detach().cpu()
        pb_label = self.kmeans.predict(pb_list.numpy())
        print(pb_label)
        # One mean pb vector per cluster.
        pb_list = torch.stack(
            [pb_list[pb_label == 0].mean(0), pb_list[pb_label == 1].mean(0)])
        _, goal_list = self.vae.decoder(pb_list)
        pb_list = self.norm(pb_list, "visual")
        # BUGFIX: uint8, not int8 — pixel values above 127 wrap negative in
        # int8, corrupting the rendered goal images.
        goal_list = ((goal_list * .5 + .5) * 255).to(torch.uint8)
        goal_list = [self.pil_transform(goal) for goal in goal_list]
        return goal_list, pb_list

    def pem(self, steps=80, lr=0.01):
        """Plan-error minimization: optimize only the pb unit against history.

        Freezes all LSTMPB weights, re-registers pb_unit as the single
        trainable parameter, and fits it so closed-loop rollouts reproduce the
        recorded visual features, regularized toward the known pb vectors.

        Args:
            steps: number of Adam iterations (default keeps the original 80).
            lr: Adam learning rate (default keeps the original 0.01).
        """
        assert len(self.his_log), "the history window is empty!"
        for param in self.rnn.parameters():
            param.requires_grad = False
        self.rnn.pb_unit = nn.Parameter(self.rnn.pb_unit, requires_grad=True)
        optim_param = [p for p in self.rnn.parameters() if p.requires_grad]
        optim = torch.optim.Adam(optim_param, lr=lr)
        mse_loss = nn.MSELoss()
        pred_his, actual_his, state_his = self.his_log.get()
        for i in range(steps):
            log = []
            # Seed with predicted joints + actually observed visual features.
            cur_input = torch.cat([pred_his[:, 0, :5], actual_his[:, 0, 5:]],
                                  dim=-1)
            state = state_his[0]
            for _ in range(1, len(state_his)):
                cur_input, state = self.rnn.step(cur_input, state)
                log.append(cur_input)
            log = torch.stack(log, dim=1)
            # Match rollout visual features to the recorded ones, keeping
            # pb_unit close to the known pb vectors.
            loss = mse_loss(log[0, :, 5:], actual_his[0, 1:, 5:]) + \
                (self.rnn.pb_unit - self.pb_list).pow(2).mean()
            # BUGFIX: clear gradients each iteration; the original never
            # called zero_grad, so gradients accumulated across all steps
            # and corrupted the update direction.
            optim.zero_grad()
            loss.backward()
            optim.step()
            print("PEM loss, step {}, loss: {}".format(i, loss.item()))

    @property
    def goal(self):
        """Current goal: the RNN's pb unit (normalized visual feature)."""
        return self.rnn.pb_unit

    @goal.setter
    def goal(self, pb):
        # Accept both (z,) and (1, z); the RNN expects a batch axis.
        if pb.ndim == 1:
            pb = torch.unsqueeze(pb, 0)
        self.rnn.pb_unit = pb

    def norm(self, inputs, mode="all"):
        """Map raw features to normalized space: x * scale + min.

        mode selects which feature dims are addressed ("joint", "visual",
        or "all"); inputs must have matching width.
        """
        assert mode in ["joint", "visual", "all"]
        i_slice = self.norm_mode[mode]
        return inputs * self.norm_scale[:, i_slice] + self.norm_min[:, i_slice]

    def denorm(self, outputs, mode="all"):
        """Inverse of norm: (x - min) / scale."""
        assert mode in ["joint", "visual", "all"]
        i_slice = self.norm_mode[mode]
        return (outputs -
                self.norm_min[:, i_slice]) / self.norm_scale[:, i_slice]