def make_direct_vid(self, separate_vid=False, resize=None):
    """Combine every entry of ``self.video_list`` into one gif and save it.

    Each entry of ``self.video_list`` is indexed as ``(frames, tag)``.
    Frames are optionally resized, and are passed through
    ``color_code_distrib`` when the first element has a trailing dimension
    of 1 or has exactly three dimensions. The processed videos are merged
    by ``assemble_gif`` and written with ``npy_to_gif`` to
    ``self.gif_savepath + '/direct<iternum><suf>'``.

    :param separate_vid: accepted but not used by this method
    :param resize: when not None, forwarded as ``size`` to ``resize_image``
        for every video
    :return: None
    """
    self.logger.log('making gif with tags')

    tagged_videos = []
    for entry in self.video_list:
        frames, tag = entry[0], entry[1]

        # Debug output describing the raw (unprocessed) entry.
        print('key', tag)
        print('len', len(frames))
        print('sizes', [frame.shape for frame in frames])
        print('####')

        if 'gen_distrib' in tag:
            # Only the backend switch is active here; no plot is drawn.
            plt.switch_backend('TkAgg')

        if resize is not None:
            frames = resize_image(frames, size=resize)

        first_shape = frames[0].shape
        if first_shape[-1] == 1 or len(first_shape) == 3:
            frames = color_code_distrib(frames, self.numex, renormalize=True)

        tagged_videos.append((frames, tag))

    gif_frames = assemble_gif(tagged_videos,
                              convert_from_float=True,
                              num_exp=self.numex)
    target_path = self.gif_savepath + '/direct{}{}'.format(
        self.iternum, self.suf)
    npy_to_gif(gif_frames, target_path)
Ejemplo n.º 2
0
    def save_gif(self, itr, overlay=False):
        """Write ``self.large_images_traj`` to a gif, optionally marking objects.

        When ``overlay`` is true and ``self.traj_points`` is not None, a
        filled circle of radius 4 is drawn in place on every frame for each
        object, using one randomly drawn colour per object.

        :param itr: value formatted into the output name ``video<itr>``
        :param overlay: whether to draw the trajectory points on the frames
        :return: None
        """
        if self.traj_points is not None and overlay:
            # One random 3-tuple in [0, 256) per object, reused on all frames.
            palette = []
            for _ in range(self.num_objects):
                palette.append(
                    tuple(np.random.randint(0, 256) for _channel in range(3)))

            for frame_points, frame in zip(self.traj_points,
                                           self.large_images_traj):
                for obj_idx, color in enumerate(palette):
                    # Points are indexed (1, 0) to build the circle centre.
                    center = (int(np.round(frame_points[obj_idx, 1])),
                              int(np.round(frame_points[obj_idx, 0])))
                    cv2.circle(frame, center, 4, color, -1)

        record_prefix = self._hyperparams['record']
        gif_name = record_prefix + 'video{}'.format(itr)
        npy_to_gif(self.large_images_traj, gif_name, fps=20)
    def reset_CEM_model(self):
        """Flush recorded frames to disk and rewind the CEM model.

        If any frames are stored in ``self.imgs`` they are written as a gif
        named ``iter_<self.iter>`` under ``self.agentparams['record']`` and
        ``self.iter`` is incremented. The model state is then overwritten
        with the stored initial qpos/qvel, both targets are set to the first
        ``self.adim`` entries of the model's qpos, the model is stepped five
        times towards that target, and the frame buffer is emptied.
        """
        frame_count = len(self.imgs)
        if frame_count > 0:
            print('saving iter', self.iter, 'with frames:', frame_count)
            gif_path = os.path.join(self.agentparams['record'],
                                    'iter_{}'.format(self.iter))
            npy_to_gif(self.imgs, gif_path)
            self.iter += 1

        # Restore the initial configuration and velocities.
        state = self.CEM_model.get_state()
        state.qpos[:] = self.initial_qpos.copy()
        state.qvel[:] = self.initial_qvel.copy()
        self.CEM_model.set_state(state)

        # Both targets start from the current qpos, as independent copies.
        controlled_qpos = self.CEM_model.data.qpos[:self.adim].squeeze()
        self.prev_target = controlled_qpos.copy()
        self.target = controlled_qpos.copy()

        step_count = 0
        while step_count < 5:
            self.step_model(self.target)
            step_count += 1

        self.imgs = []
def make_direct_vid(dict, numex, gif_savepath, suf):
    """
    Assemble the videos stored in ``dict`` into one gif and write it to disk.

    Every value must be an array (not a list) and all values must share the
    same shape. Videos whose first element has a trailing dimension of 1, or
    has exactly three dimensions, are passed through ``color_code_distrib``
    first. One all-zero frame is appended after the assembled frames before
    the gif is written to ``gif_savepath + '/direct_<suf>'``.

    :param dict:  dictionary with video tensors of shape bsize, tlen, r, c, 3
        (the name shadows the builtin; it is kept so keyword callers still work)
    :param numex: forwarded to ``color_code_distrib`` and, as ``num_exp``,
        to ``assemble_gif``
    :param gif_savepath: path prefix the output name is appended to
    :param suf: suffix formatted into the output file name
    :return: None
    :raises AssertionError: if a value is a list or the shapes differ
    """
    new_videolist = []
    previous_shape = None
    for key, images in dict.items():
        print('key', key)

        # Validate before the first ``.shape`` access: a list has no
        # ``.shape``, so checking afterwards could never report this error.
        assert not isinstance(images, list), 'videos must be arrays, not lists'
        print('shape', images.shape)

        # All videos must have the same shape; compare with the previous one.
        if previous_shape is not None:
            assert images.shape == previous_shape, 'shape is different!'
        previous_shape = images.shape

        if images[0].shape[-1] == 1 or len(images[0].shape) == 3:
            images = color_code_distrib(images, numex, renormalize=True)
        new_videolist.append((images, key))

    framelist = assemble_gif(new_videolist,
                             convert_from_float=True,
                             num_exp=numex)
    # Append one all-zero frame after the assembled frames.
    framelist.append(np.zeros_like(framelist[0]))
    npy_to_gif(framelist, gif_savepath + '/direct_{}'.format(suf))
    def reset(self, reset_state=None):
        """
        Reset the simulation: place the objects, drive the 'mocap' body to the
        chosen end-effector pose and return the first observation.

        It's pretty important that we specify which reset functions to call
        instead of using super().reset() and self.reset()
           - That's because Demonstration policies use multiple inheritance to function and the recursive
             self.reset() results in pretty nasty errors. The pro to this approach is demonstration envs are easy to create

        :param reset_state: optional; when not None it is stored in
            self._read_reset_state. Whenever self._read_reset_state is not
            None, its 'object_qpos' and 'state' entries are used instead of
            randomly sampled poses.
        :return: tuple (obs, write_reset_state). obs is the result of
            self._get_obs(finger_force) with an added 'control_delta' entry of
            four zeros; write_reset_state is a dict with the keys 'reset_xml',
            'object_qpos' and 'state'. When object sampling exceeds its retry
            limit or the simulator raises MujocoException, the result of
            BaseSawyerMujocoEnv.reset(self) is returned instead.
        """
        if reset_state is not None:
            # Stored on the instance, so it is still set when the retry paths
            # below call BaseSawyerMujocoEnv.reset(self) without an argument.
            self._read_reset_state = reset_state

        BaseMujocoEnv.reset(self)

        # last_rands collects the sampled object positions of this reset;
        # write_reset_state records everything that was chosen.
        last_rands, write_reset_state = [], {}
        write_reset_state['reset_xml'] = copy.deepcopy(self._reset_xml)

        # Minimum xy distance enforced between sampled positions.
        margin = 1.2 * self._maxlen
        if self._hp.verbose_dir is not None:
            print('resetting')

        # Flush frames collected so far into a gif (random number in the file
        # name), then start a fresh frame buffer.
        if self._hp.verbose_dir is not None and len(self._verbose_vid) > 0:
            gif_num = np.random.randint(200000)
            npy_to_gif(
                self._verbose_vid, self._hp.verbose_dir +
                '/worker{}_verbose_traj_{}'.format(os.getpid(), gif_num), 20)
            self._verbose_vid = []

        def samp_xyz_rot():
            # Draws a uniform xyz between the bounds (z is then fixed to 0.05)
            # and a rotation angle uniform in [-pi/2, pi/2].
            # low_bound / high_bound are not defined in this method.
            # NOTE(review): 0.02 is added to both bounds; the upper bound may
            # have been meant to subtract it - confirm.
            rand_xyz = np.random.uniform(
                low_bound[:3] + self._maxlen / 2 + 0.02,
                high_bound[:3] - self._maxlen / 2 + 0.02)
            rand_xyz[-1] = 0.05
            return rand_xyz, np.random.uniform(-np.pi / 2, np.pi / 2)

        # One row per object: columns 0-2 hold xyz, columns 3-6 the quaternion.
        object_poses = np.zeros((self.num_objects, 7))
        for i in range(self.num_objects):
            if self._read_reset_state is not None:
                obji_xyz = self._read_reset_state['object_qpos'][i][:3]
                obji_quat = self._read_reset_state['object_qpos'][i][3:]
            else:
                obji_xyz, rot = samp_xyz_rot()
                samp_cntr = 0
                # rejection sampling to ensure objects don't crowd each other:
                # resample while the nearest already-placed object is closer
                # than margin in the xy plane
                while len(last_rands) > 0 and min([
                        np.linalg.norm(obji_xyz[:2] - obj_j[:2])
                        for obj_j in last_rands
                ]) < margin:
                    if samp_cntr >= 100:  # avoid infinite looping by generating new env
                        return BaseSawyerMujocoEnv.reset(self)
                    obji_xyz, rot = samp_xyz_rot()
                    samp_cntr += 1
                last_rands.append(obji_xyz)

                obji_quat = Quaternion(axis=[0, 0, -1], angle=rot).elements
            object_poses[i, :3] = obji_xyz
            object_poses[i, 3:] = obji_quat

        # Put the mocap body at a fixed position with a randomly drawn z-angle.
        self.sim.data.set_mocap_pos('mocap', np.array([0, 0.5, 0.5]))
        self.sim.data.set_mocap_quat(
            'mocap',
            zangle_to_quat(np.random.uniform(low_bound[3], high_bound[3])))

        write_reset_state['object_qpos'] = copy.deepcopy(object_poses)
        # Flattened so it can be written into the tail of qpos.
        object_poses = object_poses.reshape(-1)

        # placing objects then resetting to neutral risks bad contacts
        try:
            # Re-write the object poses before each of five simulation steps.
            for s in range(5):
                self.sim.data.qpos[self._n_joints:] = object_poses.copy()
                self.sim.step()
            # Set the first nine qpos entries to the neutral values, then step
            # five more times (rendering verbose frames when enabled).
            self.sim.data.qpos[:9] = NEUTRAL_JOINTS
            for _ in range(5):
                self.sim.step()
                if self._hp.verbose_dir is not None:
                    self._render_verbose()
        except MujocoException:
            # Simulator error: start the whole reset over.
            return BaseSawyerMujocoEnv.reset(self)

        # Choose the end-effector target: stored state, random, or fixed default.
        if self._read_reset_state is not None:
            end_eff_xyz = self._read_reset_state['state'][:3]
            end_eff_quat = zangle_to_quat(self._read_reset_state['state'][3])
        elif self.randomize_initial_pos:
            end_eff_xyz = np.random.uniform(low_bound[:3], high_bound[:3])
            # Resample while the target is within margin (xy) of any object.
            # NOTE(review): unlike the object loop above there is no retry cap
            # here - confirm this cannot spin indefinitely.
            while len(last_rands) > 0 and min([
                    np.linalg.norm(end_eff_xyz[:2] - obj_j[:2])
                    for obj_j in last_rands
            ]) < margin:
                end_eff_xyz = np.random.uniform(low_bound[:3], high_bound[:3])
            end_eff_quat = zangle_to_quat(
                np.random.uniform(low_bound[3], high_bound[3]))
        else:
            end_eff_xyz = np.array([0, 0.5, 0.17])
            end_eff_quat = zangle_to_quat(np.pi)

        # 'state' is stored as 7 values: xyz followed by the quaternion.
        write_reset_state['state'] = np.zeros(7)
        write_reset_state['state'][:3], write_reset_state['state'][
            3:] = end_eff_xyz.copy(), end_eff_quat.copy()

        # Accumulator for the first two sensor readings during the settle loop.
        finger_force = np.zeros(2)
        if self._hp.verbose_dir is not None:
            print('skip_first: {}'.format(self.skip_first))

        # NOTE(review): the condition requires > 25 but the message says
        # "at least 15" - confirm which value is intended.
        assert self.skip_first > 25, "Skip first should be at least 15"
        # NOTE(review): sim_state is fetched before qpos is edited on
        # self.sim.data and only its qvel is zeroed; whether the qpos edits
        # survive set_state depends on get_state semantics - confirm.
        sim_state = self.sim.get_state()
        self.sim.data.qpos[:9] = NEUTRAL_JOINTS
        self.sim.data.qpos[self._n_joints:] = object_poses.copy()
        sim_state.qvel[:] = np.zeros_like(self.sim.data.qvel)
        self.sim.set_state(sim_state)

        # Settle loop: each of the skip_first iterations runs 20 sim steps.
        for t in range(self.skip_first):
            if t < 20:
                # First 20 iterations: hold the mocap body at the midpoint of
                # the bounds (z = 0.4, zero z-angle); during the first 5 the
                # object poses are re-written as well.
                if t < 5:
                    self.sim.data.qpos[self._n_joints:] = object_poses.copy()
                reset_xyz = (low_bound[:3] + high_bound[:3]) * 0.5
                reset_xyz[-1] = 0.4
                self.sim.data.set_mocap_pos('mocap', reset_xyz)
                self.sim.data.set_mocap_quat('mocap', zangle_to_quat(0))
                # reset gripper
                self.sim.data.qpos[7:9] = NEUTRAL_JOINTS[7:9]
                self.sim.data.ctrl[:] = [-1, 1]
            else:
                # Remaining iterations: drive the mocap body to the target pose.
                self.sim.data.set_mocap_pos('mocap', end_eff_xyz)
                self.sim.data.set_mocap_quat('mocap', end_eff_quat)
                # reset gripper
                self.sim.data.qpos[7:9] = NEUTRAL_JOINTS[7:9]
                self.sim.data.ctrl[:] = [-1, 1]

            # Verbose mode renders a frame on every second iteration.
            if self._hp.verbose_dir is not None and t % 2 == 0:
                print('skip: {}'.format(t))
                self._render_verbose()

            for _ in range(20):
                self._clip_gripper()
                try:
                    self.sim.step()

                except MujocoException:
                    #if randomly generated start causes 'bad' contacts Mujoco will error. Have to reset again
                    print('except')
                    return BaseSawyerMujocoEnv.reset(self)

            # Sampled once per outer iteration, after its 20 steps.
            if self.finger_sensors:
                finger_force += self.sim.data.sensordata[:2]
        if self._hp.verbose_dir is not None:
            print('after')
        # NOTE(review): readings are added skip_first times but the sum is
        # divided by 10 * skip_first - confirm the factor 10 is intended.
        finger_force /= 10 * self.skip_first

        # Previous target: hand xyz, hand z-angle at index 3, and
        # low_bound[-1] in the last slot.
        self._previous_target_qpos = np.zeros(self._base_sdim)
        self._previous_target_qpos[:3] = self.sim.data.get_body_xpos('hand')
        self._previous_target_qpos[3] = quat_to_zangle(
            self.sim.data.get_body_xquat('hand'))
        self._previous_target_qpos[-1] = low_bound[-1]

        self._init_dynamics()

        # Only checked when the poses came from a stored reset state.
        if self._read_reset_state is not None:
            self._check_positions(end_eff_xyz, end_eff_quat, object_poses)

        # NOTE(review): the local name `reset` is assigned but never used.
        obs, reset = self._get_obs(finger_force), write_reset_state
        obs['control_delta'] = np.zeros(4)
        return obs, write_reset_state
                        img_t = img_t[:, :, ::-1]
                    frame_imgs.append(img_t)
                img_summaries[i][int(summary_counter %
                                     args.im_per_row)].append(
                                         np.concatenate(frame_imgs, axis=1))
            summary_counter += 1

    if args.calc_deltas:
        delta_sums = np.array(delta_sums)
        adim = delta_sums.shape[-1]
        print('mean deltas: {}'.format(
            np.sum(np.sum(delta_sums, axis=0), axis=0) /
            (args.T * len(traj_names))))
        print('median delta: {}, max delta: {}'.format(
            np.median(delta_sums.reshape(-1, adim), axis=0),
            np.amax(delta_sums.reshape(-1, adim), axis=0)))
        tmaxs = np.argmax(delta_sums[:, :, -1], axis=-1)
        traj_max = np.argmax(delta_sums[np.arange(len(traj_names)), tmaxs, -1])
        print('max degree dif at traj: {}, t: {}'.format(
            traj_names[traj_max], tmaxs[traj_max]))

    print(' perc good: {}, and avg num failed rollouts: {}'.format(
        num_good / float(len(traj_names)), np.mean(rollout_fails)))

    if args.nimages > 0:
        img_summaries = [
            np.concatenate([np.concatenate(row, axis=0) for row in frame_t],
                           axis=1) for frame_t in img_summaries
        ]
        npy_to_gif(img_summaries, './summaries')