コード例 #1
0
 def conditional_generator_function(self, c_, c_next_, obs):
     """
     Generate next observations for a batch of discretized (state, next state) pairs.

     Both codes are mapped back to continuous space with ``undiscretize``, one
     Gaussian noise row per batch element is drawn with ``np.random.randn``, and
     the triple is decoded by ``self.G``. Only the generator's second output
     (the next observation) is kept; its first output is discarded.

     :param c_: discretized current states; the first dimension is the batch size.
     :param c_next_: discretized next states, batch-aligned with ``c_``.
     :param obs: unused here; presumably accepted so the signature matches the
         ``generator_function`` callback contract of ``plan_traj_astar`` — TODO confirm.
     :return: numpy array of generated next observations (copied to CPU).
     """
     c_ = undiscretize(c_, self.discretization_bins, self.P.unif_range)
     c_next_ = undiscretize(c_next_, self.discretization_bins, self.P.unif_range)
     # One independent noise vector per batch row.
     z_ = from_numpy_to_var(np.random.randn(c_.shape[0], self.rand_z_dim))
     _, next_observation = self.G(z_, from_numpy_to_var(c_), from_numpy_to_var(c_next_))
     return next_observation.data.cpu().numpy()
コード例 #2
0
 def continuous_transition_function(self, c_):
     """
     Advance discretized states one step through the learned transition model.

     The states are undiscretized, passed through ``self.T``, clipped so every
     value lies at least 1e-6 inside ``self.P.unif_range``, and discretized again
     with the same bins.

     :param c_: discretized states.
     :return: discretized successor states as produced by ``discretize``.
     """
     bins = self.discretization_bins
     value_range = self.P.unif_range
     continuous_states = undiscretize(c_, bins, value_range)
     predicted = self.T(from_numpy_to_var(continuous_states)).data.cpu().numpy()
     # Stay strictly inside the range; presumably so discretize never receives
     # a value lying exactly on a range boundary.
     margin = 1e-6
     predicted = np.clip(predicted, value_range[0] + margin,
                         value_range[1] - margin)
     return discretize(predicted, bins, value_range)
コード例 #3
0
    def astar_plan(self, c_start, c_goal, verbose=True, **kwargs):
        """
        Generate a plan in observation space given start and goal states via A* search.

        Only the first batch row of ``c_start`` / ``c_goal`` is planned. The single
        resulting state trajectory is repeated across all ``bs`` rows, and each row
        receives its own noise sample before being decoded by ``self.G``. Runs under
        ``torch.no_grad()``, so the returned tensors carry no autograd history.

        :param c_start: bs x c_dim
        :param c_goal: bs x c_dim
        :param verbose: if True, print each planned continuous state ``c_t`` (row 0).
        :param kwargs: must contain ``start_obs`` and ``goal_obs``, which are passed
            positionally to ``plan_traj_astar``.
        :return: rollout: list of generated image tensors. It holds the start frame
            followed by one frame per planned transition, documented as
            horizon x bs x channel_dim x img_W x img_H. The list is empty when the
            planned trajectory has fewer than two states.
        :raises KeyError: if ``start_obs`` or ``goal_obs`` is missing from ``kwargs``.
        """
        with torch.no_grad():
            rollout = []
            bs = c_start.size()[0]
            traj = plan_traj_astar(
                kwargs['start_obs'],
                kwargs['goal_obs'],
                start_state=c_start[0].data.cpu().numpy(),
                goal_state=c_goal[0].data.cpu().numpy(),
                transition_function=self.continuous_transition_function,
                preprocess_function=self.preprocess_function,
                discriminator_function=self.discriminator_function_np,
                generator_function=self.conditional_generator_function)

            # Map every discrete plan node back to continuous space once, instead of
            # undiscretizing each interior node twice (as "next" and then as "current").
            states = [undiscretize(node.state, self.discretization_bins,
                                   self.P.unif_range)
                      for node in traj]

            for t, (state, state_next) in enumerate(zip(states, states[1:])):
                c = from_numpy_to_var(state).repeat(bs, 1)
                c_next = from_numpy_to_var(state_next).repeat(bs, 1)
                # Draw noise from the CPU generator (same stream as before) and place it
                # on the conditioning codes' device, rather than forcing the default GPU.
                _z = torch.randn(c.size(0), self.rand_z_dim).to(c.device)

                cur_img, next_img = self.G(_z, c, c_next)
                if t == 0:
                    # The start frame is emitted once, before the first transition.
                    rollout.append(cur_img)
                rollout.append(next_img)
                if verbose:
                    print("\t c_%d: %s" % (t, print_array(c[0].data)))
        return rollout