def conditional_generator_function(self, c_, c_next_, obs):
    """
    Sample a next observation from the generator for a (state, next state) pair.

    Both states are first passed through `undiscretize` (using
    `self.discretization_bins` and `self.P.unif_range`); a fresh Gaussian
    noise batch is drawn on every call, so repeated calls with the same
    states generally return different observations.

    :param c_: current state(s), in the form accepted by `undiscretize`
    :param c_next_: next state(s), in the form accepted by `undiscretize`
    :param obs: unused here; presumably kept so the signature matches the
        callback expected by the planner -- TODO confirm against caller
    :return: the second output of `self.G`, converted to a numpy array
    """
    bins = self.discretization_bins
    current = undiscretize(c_, bins, self.P.unif_range)
    following = undiscretize(c_next_, bins, self.P.unif_range)

    # One noise row per row of the (undiscretized) current state.
    noise = np.random.randn(current.shape[0], self.rand_z_dim)
    noise_var = from_numpy_to_var(noise)
    current_var = from_numpy_to_var(current)
    following_var = from_numpy_to_var(following)

    # G returns a pair; only the second element (the next observation) is used.
    generated = self.G(noise_var, current_var, following_var)
    next_observation = generated[1]
    return next_observation.data.cpu().numpy()
def continuous_transition_function(self, c_):
    """
    Apply the transition model `self.T` to a discretized state.

    The input is mapped through `undiscretize`, pushed through `self.T`,
    clipped to lie strictly inside `self.P.unif_range`, and mapped back
    through `discretize`.

    :param c_: state(s) in the form accepted by `undiscretize`
    :return: the predicted next state(s) as returned by `discretize`
    """
    bins = self.discretization_bins
    continuous_state = undiscretize(c_, bins, self.P.unif_range)
    predicted = self.T(from_numpy_to_var(continuous_state)).data.cpu().numpy()

    # Shrink the allowed interval by a small margin on both sides so the
    # prediction never sits exactly on a range boundary (presumably to keep
    # `discretize` from producing an out-of-range bin -- TODO confirm).
    lower = self.P.unif_range[0] + 1e-6
    upper = self.P.unif_range[1] - 1e-6
    bounded = np.clip(predicted, lower, upper)

    return discretize(bounded, bins, self.P.unif_range)
def astar_plan(self, c_start, c_goal, verbose=True, **kwargs):
    """
    Generate a plan in observation space given start and goal states via A* search.

    Only the first row of `c_start` / `c_goal` is handed to `plan_traj_astar`;
    every planned state is then repeated `bs` times (bs = number of rows of
    `c_start`) before being decoded by `self.G`.

    :param c_start: bs x c_dim
    :param c_goal: bs x c_dim
    :param verbose: if True, print the current state of every plan step
    :param kwargs: must contain 'start_obs' and 'goal_obs', which are
        forwarded to `plan_traj_astar`; a missing key raises KeyError
    :return: rollout: list of images produced by `self.G`, one per node of
        the plan (each presumably bs x channel_dim x img_W x img_H -- TODO
        confirm against `self.G`). Empty when the plan has fewer than two
        nodes.
    """
    with torch.no_grad():
        rollout = []
        bs = c_start.size()[0]
        traj = plan_traj_astar(
            kwargs['start_obs'],
            kwargs['goal_obs'],
            start_state=c_start[0].data.cpu().numpy(),
            goal_state=c_goal[0].data.cpu().numpy(),
            transition_function=self.continuous_transition_function,
            preprocess_function=self.preprocess_function,
            discriminator_function=self.discriminator_function_np,
            generator_function=self.conditional_generator_function)

        # Walk over consecutive (node, next node) pairs of the plan.
        for t, (node, node_next) in enumerate(zip(traj[:-1], traj[1:])):
            state = undiscretize(node.state, self.discretization_bins, self.P.unif_range)
            state_next = undiscretize(node_next.state, self.discretization_bins, self.P.unif_range)
            c = from_numpy_to_var(state).repeat(bs, 1)
            c_next = from_numpy_to_var(state_next).repeat(bs, 1)

            # Put the noise on the same device as the states instead of
            # forcing CUDA, so the method also runs on CPU-only setups.
            _z = torch.randn(c.size()[0], self.rand_z_dim).to(c.device)

            _cur_img, _next_img = self.G(_z, c, c_next)
            if t == 0:
                # The very first frame is only available from the first step.
                rollout.append(_cur_img)
            rollout.append(_next_img)

            if verbose:
                print("\t c_%d: %s" % (t, print_array(c[0].data)))
    return rollout