Пример #1
0
    def test_environment_kwargs(self):
        """Kwargs passed to ``GymAdapter`` must reach the unwrapped env.

        With every reward-related weight set to zero, one random step is
        expected to return a reward of (almost) zero. Each kwarg must then be
        readable from the unwrapped environment as a ``_<name>`` attribute.
        """
        env_kwargs = {
            'forward_reward_weight': 0.0,
            'ctrl_cost_weight': 0.0,
            'reset_noise_scale': 0.0,
        }
        env = GymAdapter(domain='Swimmer', task='v3', **env_kwargs)

        random_action = env.action_space.sample()
        _, reward, _, _ = env.step(random_action)
        self.assertAlmostEqual(reward, 0.0)

        for name, expected in env_kwargs.items():
            stored = getattr(env.unwrapped, f'_{name}')
            self.assertEqual(stored, expected)
Пример #2
0
def flatten_multiworld_env(env):
    """Wrap a multiworld goal env as a flat, image-based ``GymAdapter``.

    ``FlatGoalEnv`` is configured to use ``image_observation`` as the
    observation key and ``image_desired_goal`` as the goal key, with
    ``append_goal_to_obs=True``. The flattened env is then wrapped in a
    ``GymAdapter`` and returned.
    """
    from multiworld.core.flat_goal_env import FlatGoalEnv

    flat_goal_env = FlatGoalEnv(
        env,
        obs_keys=['image_observation'],
        goal_keys=['image_desired_goal'],
        append_goal_to_obs=True,
    )
    return GymAdapter(env=flat_goal_env)
Пример #3
0
    def test_rescale_observation(self):
        """``rescale_observation_range`` maps the observation box onto it.

        The unscaled MountainCar observation bounds must differ from
        [-1, 1]. After wrapping with ``rescale_observation_range=(-1, 1)``,
        both bounds must equal the requested range.
        """
        environment_kwargs = {
            'domain': 'MountainCar',
            'task': 'Continuous-v0',
        }
        new_low, new_high = -1.0, 1.0

        environment = GymAdapter(**environment_kwargs)
        observation_space = environment.env.observation_space
        assert isinstance(observation_space, spaces.Box)
        assert np.any(observation_space.low != new_low)
        assert np.any(observation_space.high != new_high)

        target_range = (new_low, new_high)
        rescaled_environment = GymAdapter(
            **environment_kwargs, rescale_observation_range=target_range)
        rescaled_space = rescaled_environment.env.observation_space
        np.testing.assert_allclose(rescaled_space.low, new_low)
        np.testing.assert_allclose(rescaled_space.high, new_high)
Пример #4
0
    def test_rescale_action(self):
        """``rescale_action_range`` maps Pendulum's action box onto it.

        With ``rescale_action_range=None``, the action bounds must differ
        from [-1, 1]. With ``rescale_action_range=(-1, 1)``, both bounds
        must equal the requested range.
        """
        environment_kwargs = {
            'domain': 'Pendulum',
            'task': 'v0',
        }
        new_low, new_high = -1.0, 1.0

        environment = GymAdapter(**environment_kwargs,
                                 rescale_action_range=None)
        action_space = environment.action_space
        assert isinstance(action_space, spaces.Box)
        assert np.any(action_space.low != new_low)
        assert np.any(action_space.high != new_high)

        target_range = (new_low, new_high)
        rescaled_environment = GymAdapter(**environment_kwargs,
                                          rescale_action_range=target_range)
        rescaled_space = rescaled_environment.action_space
        np.testing.assert_allclose(rescaled_space.low, new_low)
        np.testing.assert_allclose(rescaled_space.high, new_high)
def main():
    """Collect goal-image examples for ``TurnFreeValve3MultiGoal-v0``.

    For every (angle, position) goal pair, random actions are rolled out.
    Each observation for which ``env.get_goal_completion()`` is truthy is
    kept as a positive example, tagged with its ``goal_index``.

    Exactly ``NUM_TOTAL_EXAMPLES`` positives are collected per goal. They
    are optionally written as JPEG images, and always pickled as a dict of
    arrays (stacked along a new leading axis) to ``<path>/positives.pkl``.

    Relies on a module-level ``directory`` (defined elsewhere), which is
    used as a path prefix.
    """
    pos_goals = [(0.01, 0.01), (-0.01, -0.01)]
    angle_goals = [180, 0]
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 200, 25, 4
    ANGLE_THRESHOLD, POSITION_THRESHOLD = 0.15, 0.035
    images = True
    image_shape = (32, 32, 3)

    for goal_index, (angle_goal, pos_goal) in enumerate(
            zip(angle_goals, pos_goals)):
        num_positives = 0
        goal_radians = np.pi / 180. * angle_goal  # convert to radians
        observations = []

        x, y = pos_goal
        env_kwargs = {
            'camera_settings': {
                'azimuth': 0,
                'distance': 0.32,
                'elevation': -45,
                'lookat': (0, 0, 0.03)
            },
            'pixel_wrapper_kwargs': {
                'pixels_only': False,
                'normalize': False,
                'render_kwargs': {
                    'width': image_shape[0],
                    'height': image_shape[1],
                    'camera_id': -1,
                }
            },
            # Alternative camera settings used previously:
            # 'camera_settings': {
            #     'azimuth': 45.,
            #     'distance': 0.32,
            #     'elevation': -55.88,
            #     'lookat': np.array([0.00097442, 0.00063182, 0.03435371])
            # },
            # 'camera_settings': {
            #     'azimuth': 30.,
            #     'distance': 0.35,
            #     'elevation': -38.18,
            #     'lookat': np.array([0.00047, -0.0005, 0.054])
            # },
            'goals': ((x, y, 0, 0, 0, goal_radians), ),
            'goal_collection': True,
            'init_angle_range': (goal_radians - 0.05, goal_radians + 0.05),
            'target_angle_range': (goal_radians, goal_radians),
            'observation_keys':
            ('pixels', 'claw_qpos', 'last_action', 'goal_index'),
            'goal_completion_orientation_threshold': ANGLE_THRESHOLD,
            'goal_completion_position_threshold': POSITION_THRESHOLD,
        }

        env = GymAdapter(domain='DClaw',
                         task='TurnFreeValve3MultiGoal-v0',
                         **env_kwargs)

        # Plain concatenation (no separator) as before: ``directory`` acts
        # as a path prefix.
        path = directory + str(angle_goal)
        os.makedirs(path, exist_ok=True)

        # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
        # combined with no early exit from the rollout, always overshot.
        while num_positives < NUM_TOTAL_EXAMPLES:
            env.reset()
            print("Resetting environment...")
            for _rollout_step in range(ROLLOUT_LENGTH):
                action = env.action_space.sample()
                # Repeat each sampled action for several env steps.
                for _ in range(STEPS_PER_SAMPLE):
                    observation, _, _, _ = env.step(action)

                if env.get_goal_completion():
                    # Tag the example with the index of the goal being
                    # collected ('goal_index' is also an observation key).
                    observation['goal_index'] = np.array([goal_index])
                    observations.append(observation)
                    print(observation)
                    if images:
                        imageio.imwrite(path + '/img%i.jpg' % num_positives,
                                        observation['pixels'])
                    num_positives += 1
                    if num_positives >= NUM_TOTAL_EXAMPLES:
                        break

        # One array per observation key, stacked along a new leading axis.
        goal_examples = {
            key: np.stack([obs[key] for obs in observations], axis=0)
            for key in observations[0]
        }

        with open(path + '/positives.pkl', 'wb') as file:
            pickle.dump(goal_examples, file)
Пример #6
0
 def create_adapter(self, domain='Swimmer', task='Default'):
     """Return a ``GymAdapter`` for ``domain``/``task`` (keyword args only)."""
     adapter = GymAdapter(domain=domain, task=task)
     return adapter
def main():
    """Collect goal-image examples for ``SlideBeadsFixed-v0``.

    For each target bead configuration in ``goals``, random actions are
    rolled out. Every observation whose largest entry in
    ``obs_dict['objects_to_targets_distances']`` is below
    ``POSITION_THRESHOLD`` is kept as a positive, tagged with a float32
    ``goal_index``.

    Exactly ``NUM_TOTAL_EXAMPLES`` positives per goal are optionally saved
    as PNGs. Image ``img<i>.png`` corresponds to row ``i`` of the pickle.
    The positives are pickled as a dict of stacked arrays to
    ``<path>/positives.pkl``.

    Relies on a module-level ``directory`` (defined elsewhere), which is
    used as a path prefix.
    """
    # Goal 0 = sides, Goal 1 = middle.
    goals = [
        np.array([-0.0475, -0.0475, 0.0475, 0.0475]),
        np.array([0, 0, 0, 0])
    ]
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 500, 25, 4
    POSITION_THRESHOLD = 0.015
    images = True
    image_shape = (32, 32, 3)

    goal_keys = (
        'pixels',
        'goal_index',
        'claw_qpos',
        'last_action',
    )

    for goal_index, goal in enumerate(goals):
        num_positives = 0
        observations = []

        env_kwargs = {
            'pixel_wrapper_kwargs': {
                'pixels_only': False,
                'normalize': False,
                'render_kwargs': {
                    'width': image_shape[0],
                    'height': image_shape[1],
                    'camera_id': -1
                },
            },
            # Previous camera settings:
            # 'camera_settings': {
            #     'azimuth': 23.234042553191497,
            #     'distance': 0.2403358053524018,
            #     'elevation': -29.68085106382978,
            #     'lookat': (-0.00390331,  0.01236683,  0.01093447),
            # },
            'camera_settings': {
                'azimuth': 90,
                'distance': 0.37,
                'elevation': -45,
                'lookat': (0, 0.046, -0.016),
            },
            'target_qpos_range': [goal],
            'observation_keys': goal_keys,
            'init_qpos_range': (goal - 0.01, goal + 0.01),
            'num_objects': 4,
        }
        env = GymAdapter(domain='DClaw',
                         task='SlideBeadsFixed-v0',
                         **env_kwargs)

        # Plain concatenation (no separator) as before: ``directory`` acts
        # as a path prefix.
        path = directory + str(goal)
        os.makedirs(path, exist_ok=True)

        # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
        # combined with no early exit from the rollout, always overshot.
        while num_positives < NUM_TOTAL_EXAMPLES:
            env.reset()
            print("Resetting environment...")
            for _rollout_step in range(ROLLOUT_LENGTH):
                action = env.action_space.sample()
                # Repeat each sampled action for several env steps.
                for _ in range(STEPS_PER_SAMPLE):
                    observation, _, _, _ = env.step(action)

                obs_dict = env.get_obs_dict()
                # Positive only if every bead is within the threshold.
                if (np.max(obs_dict['objects_to_targets_distances'])
                        < POSITION_THRESHOLD):
                    observation['goal_index'] = np.array(
                        [goal_index]).astype(np.float32)
                    observations.append(observation)
                    print(observation)
                    if images:
                        # Write before incrementing so the image index
                        # matches the example's row in positives.pkl.
                        imageio.imwrite(path + '/img%i.png' % num_positives,
                                        observation['pixels'])
                    num_positives += 1
                    if num_positives >= NUM_TOTAL_EXAMPLES:
                        break

        # One array per observation key, stacked along a new leading axis.
        goal_examples = {
            key: np.stack([obs[key] for obs in observations], axis=0)
            for key in observations[0]
        }

        with open(path + '/positives.pkl', 'wb') as file:
            pickle.dump(goal_examples, file)
def main():
    """Collect goal-image examples for the bowl-arena free-valve task.

    For each (angle, position) goal pair, random actions are rolled out in
    ``TurnFreeValve3MultiGoal-v0`` with ``use_bowl_arena=True``. Each
    observation for which ``env.get_goal_completion()`` is truthy is kept,
    tagged with a float32 ``goal_index``.

    Exactly ``NUM_TOTAL_EXAMPLES`` positives per goal are optionally saved
    as PNGs. Image ``img<i>.png`` corresponds to row ``i`` of the pickle.
    The positives are pickled as a dict of stacked arrays to
    ``<path>/positives.pkl``.

    Relies on a module-level ``directory`` (defined elsewhere), which is
    used as a path prefix.
    """
    pos_goals = [(0, 0), (0, 0)]
    angle_goals = [0, 90]
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 200, 25, 4
    ANGLE_THRESHOLD, POSITION_THRESHOLD = 0.15, 0.035
    images = True
    image_shape = (32, 32, 3)

    goal_keys = (
        'pixels',
        'goal_index',
    )

    for goal_index, (angle_goal, pos_goal) in enumerate(
            zip(angle_goals, pos_goals)):
        num_positives = 0
        goal_radians = np.pi / 180. * angle_goal  # convert to radians
        observations = []

        x, y = pos_goal
        env_kwargs = {
            'pixel_wrapper_kwargs': {
                'pixels_only': False,
                'normalize': False,
                'render_kwargs': {
                    'width': image_shape[0],
                    'height': image_shape[1],
                    'camera_id': -1
                },
            },
            'camera_settings': {
                'azimuth': 180,
                'distance': 0.26,
                'elevation': -40,
                'lookat': (0, 0, 0.06),
            },
            'goals': ((x, y, 0, 0, 0, goal_radians), ),
            'observation_keys': goal_keys,
            'goal_completion_position_threshold': POSITION_THRESHOLD,
            'goal_completion_orientation_threshold': ANGLE_THRESHOLD,
            'goal_collection': True,
            'init_qpos_range': ((0, 0, 0, 0, 0, goal_radians - 0.025),
                                (0, 0, 0, 0, 0, goal_radians + 0.025)),
            'target_qpos_range': [(x, y, 0, 0, 0, goal_radians)],
            'use_bowl_arena': True,
        }
        env = GymAdapter(domain='DClaw',
                         task='TurnFreeValve3MultiGoal-v0',
                         **env_kwargs)

        # Plain concatenation (no separator) as before: ``directory`` acts
        # as a path prefix.
        path = directory + str(angle_goal)
        os.makedirs(path, exist_ok=True)

        # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
        # combined with no early exit from the rollout, always overshot.
        while num_positives < NUM_TOTAL_EXAMPLES:
            env.reset()
            print("Resetting environment...")
            for _rollout_step in range(ROLLOUT_LENGTH):
                action = env.action_space.sample()
                # Repeat each sampled action for several env steps.
                for _ in range(STEPS_PER_SAMPLE):
                    observation, _, _, _ = env.step(action)

                if env.get_goal_completion():
                    observation['goal_index'] = np.array(
                        [goal_index]).astype(np.float32)
                    observations.append(observation)
                    print(observation)
                    if images:
                        # Write before incrementing so the image index
                        # matches the example's row in positives.pkl.
                        imageio.imwrite(path + '/img%i.png' % num_positives,
                                        observation['pixels'])
                    num_positives += 1
                    if num_positives >= NUM_TOTAL_EXAMPLES:
                        break

        # One array per observation key, stacked along a new leading axis.
        goal_examples = {
            key: np.stack([obs[key] for obs in observations], axis=0)
            for key in observations[0]
        }

        with open(path + '/positives.pkl', 'wb') as file:
            pickle.dump(goal_examples, file)
Пример #9
0
 def verify_reset_and_step(domain, task):
     """Smoke-test an env: build it, reset it once, take one random step."""
     environment = GymAdapter(domain=domain, task=task)
     environment.reset()
     random_action = environment.action_space.sample()
     environment.step(random_action)
Пример #10
0
 def create_adapter(self, domain='Swimmer', task='v3', *args, **kwargs):
     """Build a ``GymAdapter``, passing ``domain``, ``task`` and extras on."""
     positional = (domain, task, *args)
     return GymAdapter(*positional, **kwargs)
Пример #11
0
def main():
    """Collect goal-image examples for each angle goal of ``TurnMultiGoal-v0``.

    For every goal angle (degrees) in ``goals``, random actions are rolled
    out. Each observation whose ``obs_dict['object_to_target_angle_dist']``
    is below ``ANGLE_THRESHOLD`` is kept. It is tagged with its goal index:
    a float32 one-hot vector if ``one_hot_goal_index``, else a 1-element
    array. Exactly ``NUM_TOTAL_EXAMPLES`` positives are collected per goal.

    With ``mixed_goal_pool``, the examples of all goals accumulate into one
    pool saved under ``directory``; ``positives.pkl`` is rewritten after
    each goal. Otherwise each goal gets its own ``directory/<goal>`` folder.

    Relies on a module-level ``directory`` defined elsewhere.
    """
    mixed_goal_pool = False
    one_hot_goal_index = False
    images = True
    goals = [90, -90]
    num_goals = len(goals)

    image_shape = (32, 32, 3)
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 500, 25, 5
    ANGLE_THRESHOLD = 0.15
    observations = []

    for goal_index, goal in enumerate(goals):
        if not mixed_goal_pool:
            observations = []  # start a fresh pool for this goal

        num_positives = 0
        goal_angle = np.pi / 180. * goal  # convert to radians

        env_kwargs = {
            'camera_settings': {
                'azimuth': 0.,
                'distance': 0.32,
                'elevation': -45,
                'lookat': np.array([0, 0, 0.06])
            },
            'goals': (goal_angle, ),
            'goal_collection': True,
            'init_object_pos_range': (goal_angle - 0.05, goal_angle + 0.05),
            'target_pos_range': (goal_angle, goal_angle),
            'pixel_wrapper_kwargs': {
                'pixels_only': False,
                'normalize': False,
                'render_kwargs': {
                    'width': image_shape[0],
                    'height': image_shape[1],
                    'camera_id': -1
                },
            },
            'swap_goals_upon_completion': True,
            'observation_keys':
            ('pixels', 'claw_qpos', 'last_action', 'goal_index'),
        }
        env = GymAdapter(domain='DClaw', task='TurnMultiGoal-v0', **env_kwargs)

        if mixed_goal_pool:
            path = directory
        else:
            path = os.path.join(directory, str(goal))
        os.makedirs(path, exist_ok=True)

        # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
        # combined with no early exit from the rollout, always overshot.
        while num_positives < NUM_TOTAL_EXAMPLES:
            env.reset()
            print("Resetting environment...")
            for _rollout_step in range(ROLLOUT_LENGTH):
                action = env.action_space.sample()
                # Repeat each sampled action for several env steps.
                for _ in range(STEPS_PER_SAMPLE):
                    observation, _, _, _ = env.step(action)

                obs_dict = env.get_obs_dict()
                # For fixed screw
                object_target_angle_dist = obs_dict[
                    'object_to_target_angle_dist']

                if object_target_angle_dist < ANGLE_THRESHOLD:
                    if one_hot_goal_index:
                        one_hot = np.zeros(num_goals).astype(np.float32)
                        one_hot[goal_index] = 1.
                        observation['goal_index'] = one_hot
                    else:
                        observation['goal_index'] = np.array([goal_index])
                    observations.append(observation)
                    print(observation)
                    if images:
                        imageio.imwrite(
                            os.path.join(path,
                                         f'img_{goal}_{num_positives}.jpg'),
                            observation['pixels'])
                    num_positives += 1
                    if num_positives >= NUM_TOTAL_EXAMPLES:
                        break

        # One array per observation key, stacked along a new leading axis.
        goal_examples = {
            key: np.stack([obs[key] for obs in observations], axis=0)
            for key in observations[0]
        }

        with open(os.path.join(path, 'positives.pkl'), 'wb') as file:
            pickle.dump(goal_examples, file)
Пример #12
0
def main():
    """Collect goal-image examples for ``TurnFixed-v0`` at angle pi.

    Random actions are rolled out, and every observation whose
    ``obs_dict['object_to_target_angle_distance']`` is below
    ``ANGLE_THRESHOLD`` is kept. Exactly ``NUM_TOTAL_EXAMPLES`` positives
    are optionally saved as PNGs and pickled as a dict of stacked arrays to
    ``<directory>/positives.pkl``.

    Relies on a module-level ``directory`` defined elsewhere.
    """
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 250, 25, 4
    ANGLE_THRESHOLD = 0.15
    goal_angle = np.pi
    num_positives = 0
    observations = []
    images = True
    image_shape = (32, 32, 3)

    env_kwargs = {
        'pixel_wrapper_kwargs': {
            'pixels_only': False,
            'normalize': False,
            'render_kwargs': {
                'width': image_shape[0],
                'height': image_shape[1],
                'camera_id': -1,
            },
        },
        'camera_settings': {
            'azimuth': 0.,
            'distance': 0.35,
            'elevation': -38.17570837642188,
            'lookat': np.array([0.00046945, -0.00049496, 0.05389398]),
        },
        'init_pos_range': (goal_angle - 0.05, goal_angle + 0.05),
        'target_pos_range': (goal_angle, goal_angle),
        'observation_keys': (
            'pixels',
            'claw_qpos',
            'last_action',
            'object_xy_position',
            'object_z_orientation_cos',
            'object_z_orientation_sin',
        ),
    }
    env = GymAdapter(domain='DClaw', task='TurnFixed-v0', **env_kwargs)

    # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
    # combined with no early exit from the rollout, always overshot.
    while num_positives < NUM_TOTAL_EXAMPLES:
        env.reset()
        print("Resetting environment...")
        for _rollout_step in range(ROLLOUT_LENGTH):
            action = env.action_space.sample()
            # Repeat each sampled action for several env steps.
            for _ in range(STEPS_PER_SAMPLE):
                observation, _, _, _ = env.step(action)

            obs_dict = env.get_obs_dict()
            circle_dist = obs_dict['object_to_target_angle_distance']
            print(f"Circle dist: {circle_dist}")

            if circle_dist < ANGLE_THRESHOLD:
                observations.append(observation)
                print(observation)
                if images:
                    img_obs = np.asarray(observation['pixels'])
                    # Only normalized (floating) pixels in [-1, 1] need
                    # mapping back to [0, 255]. With 'normalize': False the
                    # pixels are presumably raw uint8, where the old
                    # ``img_obs + 1`` would wrap 255 -> 0.
                    if np.issubdtype(img_obs.dtype, np.floating):
                        img_obs = np.clip(255 / 2 * (img_obs + 1), 0, 255)
                    imageio.imwrite(directory + '/img%i.png' % num_positives,
                                    img_obs.astype(np.uint8))
                num_positives += 1
                if num_positives >= NUM_TOTAL_EXAMPLES:
                    break

    # One array per observation key, stacked along a new leading axis.
    goal_examples = {
        key: np.stack([obs[key] for obs in observations], axis=0)
        for key in observations[0]
    }

    with open(directory + '/positives.pkl', 'wb') as file:
        pickle.dump(goal_examples, file)
Пример #13
0
    def _build(self):
        """Build environments, pool, sampler, networks and the algorithm.

        Configuration is read from a deep copy of ``self._variant``. The
        built components are stored on ``self``: ``training_environment``,
        ``evaluation_environment``, ``replay_pool``, ``sampler``, ``Qs``,
        ``policy``, ``initial_exploration_policy``, ``algorithm`` and, for
        VICE goal-conditioned algorithm types, ``reward_classifier``.

        Finally, TF variables in ``self._session`` that are not yet
        initialized get initialized, and ``self._built`` is set to True.
        """
        variant = copy.deepcopy(self._variant)

        # Earlier goal-example environment construction, kept for reference:
        #training_environment = self.training_environment = (
        #    get_goal_example_environment_from_variant(
        #        variant['task'], gym_adapter=False))

        # Training and evaluation envs share domain and env_params; they
        # differ only in task ('task' vs 'task_evaluation').
        training_environment = self.training_environment = (GymAdapter(
            domain=variant['domain'],
            task=variant['task'],
            **variant['env_params']))

        #evaluation_environment = self.evaluation_environment = (
        #    get_goal_example_environment_from_variant(
        #        variant['task_evaluation'], gym_adapter=False))
        evaluation_environment = self.evaluation_environment = (GymAdapter(
            domain=variant['domain'],
            task=variant['task_evaluation'],
            **variant['env_params']))

        # Disabled multiworld flattening / re-wrapping, kept for reference:
        # training_environment = self.training_environment = (
        #     flatten_multiworld_env(self.training_environment))
        # evaluation_environment = self.evaluation_environment = (
        #     flatten_multiworld_env(self.evaluation_environment))
        #training_environment = self.training_environment = (
        #        GymAdapter(env=training_environment))
        #evaluation_environment = self.evaluation_environment = (
        #        GymAdapter(env=evaluation_environment))

        # make sure this is her replay pool
        # NOTE(review): "her" presumably means a hindsight-experience-replay
        # pool; the pool type comes from the variant - confirm.
        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            variant, training_environment))
        sampler = self.sampler = get_sampler_from_variant(variant)
        Qs = self.Qs = get_Q_function_from_variant(variant,
                                                   training_environment)
        policy = self.policy = get_policy_from_variant(variant,
                                                       training_environment)
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(variant['exploration_policy_params'],
                                   training_environment))

        # Note: the algorithm receives the original ``self._variant``, not
        # the deep copy used above.
        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        # Goal-conditioned VICE variants additionally need a reward
        # classifier and goal examples.
        if self._variant['algorithm_params']['type'] in [
                'VICEGoalConditioned', 'VICEGANGoalConditioned'
        ]:
            reward_classifier = self.reward_classifier = (
                get_reward_classifier_from_variant(self._variant,
                                                   training_environment))
            algorithm_kwargs['classifier'] = reward_classifier

            # Loading real goal examples is disabled; (1, 1) uninitialized
            # arrays are passed as placeholders instead.
            # goal_examples_train, goal_examples_validation = \
            #     get_goal_example_from_variant(variant)
            algorithm_kwargs['goal_examples'] = np.empty((1, 1))
            algorithm_kwargs['goal_examples_validation'] = np.empty((1, 1))

        # RND: networks are built only when 'rnd_params' is truthy (the key
        # must be present in algorithm_params); otherwise an empty tuple.
        if variant['algorithm_params']['rnd_params']:
            from softlearning.rnd.utils import get_rnd_networks_from_variant
            rnd_networks = get_rnd_networks_from_variant(
                variant, training_environment)
        else:
            rnd_networks = ()
        algorithm_kwargs['rnd_networks'] = rnd_networks

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)

        # Initialize only the variables that have not been initialized yet.
        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True
def main():
    """Collect goal-image examples for ``TurnFreeValve3Fixed-v0`` at angle pi.

    An observation counts as a positive when both of these hold:
    ``obs_dict['object_to_target_circle_distance'] < ANGLE_THRESHOLD`` and
    ``obs_dict['object_to_target_position_distance'] < POSITION_THRESHOLD``.

    Exactly ``NUM_TOTAL_EXAMPLES`` positives are optionally saved as PNGs
    and pickled as a dict of stacked arrays to
    ``<directory>/positives.pkl``.

    Relies on a module-level ``directory`` defined elsewhere.
    """
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 250, 25, 4
    ANGLE_THRESHOLD, POSITION_THRESHOLD = 0.15, 0.035
    goal_angle = np.pi
    num_positives = 0
    observations = []
    images = True
    image_shape = (32, 32, 3)

    env_kwargs = {
        'pixel_wrapper_kwargs': {
            'pixels_only': False,
            'normalize': False,
            'render_kwargs': {
                'width': image_shape[0],
                'height': image_shape[1],
                'camera_id': -1,
            },
        },
        'camera_settings': {
            'azimuth': 180,
            'distance': 0.35,
            'elevation': -55,
            'lookat': np.array([0, 0, 0.03]),
        },
        'init_qpos_range': ((0, 0, 0, 0, 0, goal_angle - 0.05),
                            (0, 0, 0, 0, 0, goal_angle + 0.05)),
        'target_qpos_range':
        ((0, 0, 0, 0, 0, goal_angle), (0, 0, 0, 0, 0, goal_angle)),
        'observation_keys':
        ('pixels', 'claw_qpos', 'last_action', 'object_xy_position',
         'object_z_orientation_cos', 'object_z_orientation_sin'),
    }
    env = GymAdapter(domain='DClaw',
                     task='TurnFreeValve3Fixed-v0',
                     **env_kwargs)

    # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
    # combined with no early exit from the rollout, always overshot.
    while num_positives < NUM_TOTAL_EXAMPLES:
        env.reset()
        print("Resetting environment...")
        for _rollout_step in range(ROLLOUT_LENGTH):
            action = env.action_space.sample()
            # Repeat each sampled action for several env steps.
            for _ in range(STEPS_PER_SAMPLE):
                observation, _, _, _ = env.step(action)
            obs_dict = env.get_obs_dict()

            circle_dist = obs_dict['object_to_target_circle_distance']
            pos_dist = obs_dict['object_to_target_position_distance']
            print(f"Circle dist: {circle_dist}, Position dist: {pos_dist}")

            if (circle_dist < ANGLE_THRESHOLD
                    and pos_dist < POSITION_THRESHOLD):
                observations.append(observation)
                print(observation)
                if images:
                    skimage.io.imsave(directory + f'/img_{num_positives}.png',
                                      observation['pixels'])
                num_positives += 1
                if num_positives >= NUM_TOTAL_EXAMPLES:
                    break

    # One array per observation key, stacked along a new leading axis.
    goal_examples = {
        key: np.stack([obs[key] for obs in observations], axis=0)
        for key in observations[0]
    }

    with open(directory + '/positives.pkl', 'wb') as file:
        pickle.dump(goal_examples, file)
Пример #15
0
def main():
    """Collect random-action trajectories for a DClaw task and pickle them.

    Per timestep, each trajectory records:
    - every observation key returned by ``env.step``;
    - the sampled action, under ``'actions'``;
    - a task-specific ground-truth state, under ``'states'``.

    With ``--split-trajectories True``, the output is a list of
    per-trajectory dicts of stacked arrays. Otherwise it is a single dict
    mapping each key to an array stacked over all timesteps of all
    trajectories.

    The result is written gzip-compressed to
    ``<script dir>/<save-path-name>/data.pkl``.
    """
    import ast

    # Tasks for which a ground-truth state can be computed below.
    supported_tasks = ('TurnFreeValve3Fixed-v0', 'TurnFixed-v0',
                       'SlideBeadsFixed-v0')

    parser = argparse.ArgumentParser()
    parser.add_argument('--num-trajectories',
                        type=int,
                        default=500,
                        help='Number of trajectories to collect')
    parser.add_argument('--save-path-name',
                        type=str,
                        default='screw_data',
                        help='Save directory name')
    parser.add_argument('--rollout-length',
                        type=int,
                        default=25,
                        help='Number of timesteps per rollout')
    # NOTE: currently unused; periodic dumping is not implemented.
    parser.add_argument('--dump-frequency',
                        type=int,
                        default=100,
                        help='Number of trajectories per dump')
    # ``ast.literal_eval`` parses the same tuple/bool literals the previous
    # ``eval`` did, without executing arbitrary command-line code.
    parser.add_argument('--image-shape',
                        type=ast.literal_eval,
                        default=(32, 32, 3),
                        help='(width, height, channels) to save for pixels')
    parser.add_argument('--save-images',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether or not to save images while collecting')
    parser.add_argument(
        '--split-trajectories',
        type=ast.literal_eval,
        default=False,
        help='Whether or not to split by trajectory during collection')
    # Restrict to supported tasks: any other task used to crash on the
    # first step with an unbound ``state``.
    parser.add_argument('--task',
                        type=str,
                        default='TurnFreeValve3Fixed-v0',
                        choices=supported_tasks,
                        help='Task to collect data for')

    args = parser.parse_args()

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    directory = os.path.join(cur_dir, args.save_path_name)
    os.makedirs(directory, exist_ok=True)

    NUM_TOTAL_TRAJECTORIES = args.num_trajectories
    ROLLOUT_LENGTH = args.rollout_length

    image_shape = args.image_shape
    save_images = args.save_images
    should_split_trajectories = args.split_trajectories

    trajectories = [] if should_split_trajectories else {}

    task = args.task
    env_kwargs = get_environment_params(task, image_shape)
    env = GymAdapter(domain='DClaw', task=task, **env_kwargs)

    for n_trajectory in range(NUM_TOTAL_TRAJECTORIES):
        # All the observation keys are added on the first step below.
        trajectory = {
            'actions': [],
            'states': [],
        }
        env.reset()

        for t in range(ROLLOUT_LENGTH):
            # 1. Sample an action uniformly and perform it.
            action = env.action_space.sample()

            # 2. Collect observations (including claw position and pixels).
            observation, _, _, _ = env.step(action)
            obs_dict = env.get_obs_dict()

            for k, v in observation.items():
                trajectory.setdefault(k, []).append(v)
            trajectory['actions'].append(action)

            # 3. Compute the ground-truth state.
            if task == 'TurnFreeValve3Fixed-v0':
                xy = normalize(obs_dict['object_xy_position'], -0.1, 0.1, -1,
                               1)
                cos = obs_dict['object_orientation_cos'][2]
                sin = obs_dict['object_orientation_sin'][2]
                state = np.concatenate([xy, cos[None], sin[None]])
            elif task == 'TurnFixed-v0':
                state = np.concatenate([
                    observation['object_angle_cos'],
                    observation['object_angle_sin']
                ])
            else:  # 'SlideBeadsFixed-v0', guaranteed by ``choices``
                state = obs_dict['objects_positions']
            trajectory['states'].append(state)

            # Include the trajectory index so images from later
            # trajectories no longer overwrite earlier ones.
            if save_images:
                skimage.io.imsave(
                    os.path.join(directory, f'img_{n_trajectory}_{t}.png'),
                    observation['pixels'])

        if should_split_trajectories:
            trajectories.append(
                {k: np.stack(v) for k, v in trajectory.items()})
        else:
            # Extend in place (repeated list concatenation was quadratic).
            for k, v in trajectory.items():
                trajectories.setdefault(k, []).extend(v)

        if n_trajectory % 20 == 0:
            print(f"\n{n_trajectory} trajectories collected...")

    # In the merged layout, stack each key over all collected timesteps.
    if not should_split_trajectories:
        for k, v in trajectories.items():
            trajectories[k] = np.stack(v)

    with gzip.open(os.path.join(directory, 'data.pkl'), 'wb') as f:
        pickle.dump(trajectories, f)
def main():
    """Collect goal-image examples for ``TurnMultiGoal-v0`` at angle pi.

    Random actions are rolled out, and every observation whose
    ``obs_dict['object_to_target_angle_dist']`` is below ``ANGLE_THRESHOLD``
    is kept as a positive, with ``goal_index`` set to ``[1]``. Exactly
    ``NUM_TOTAL_EXAMPLES`` positives are optionally written as JPEGs and
    pickled as a dict of stacked arrays to ``<directory>/positives.pkl``.

    Relies on a module-level ``directory`` defined elsewhere.
    """
    NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 50, 25, 5
    ANGLE_THRESHOLD = 0.15
    goal_angle = np.pi
    num_positives = 0
    observations = []
    images = True
    image_shape = (32, 32, 3)

    env_kwargs = {
        'camera_settings': {
            'azimuth': 0.,
            'distance': 0.32,
            'elevation': -44.72107438016526,
            'lookat': np.array([0.00815854, -0.00548645, 0.08652757])
        },
        'goals': (goal_angle, ),
        'goal_collection': True,
        'init_object_pos_range': (goal_angle - 0.05, goal_angle + 0.05),
        'target_pos_range': (goal_angle, goal_angle),
        'pixel_wrapper_kwargs': {
            'pixels_only': False,
            # Render at the same size the saved images are reshaped to.
            'render_kwargs': {
                'width': image_shape[0],
                'height': image_shape[1],
                'camera_id': -1
            }
        },
        'swap_goals_upon_completion': True,
        # Save the goal index so the classifier can mask on it.
        'observation_keys':
        ('pixels', 'claw_qpos', 'last_action', 'goal_index'),
    }
    env = GymAdapter(domain='DClaw', task='TurnMultiGoal-v0', **env_kwargs)

    # Collect exactly NUM_TOTAL_EXAMPLES positives. The old ``<=`` test,
    # combined with no early exit from the rollout, always overshot.
    while num_positives < NUM_TOTAL_EXAMPLES:
        env.reset()
        print("Resetting environment...")
        for _rollout_step in range(ROLLOUT_LENGTH):
            action = env.action_space.sample()
            # Repeat each sampled action for several env steps.
            for _ in range(STEPS_PER_SAMPLE):
                observation, _, _, _ = env.step(action)

            obs_dict = env.get_obs_dict()
            # For fixed screw
            object_target_angle_dist = obs_dict['object_to_target_angle_dist']

            if object_target_angle_dist < ANGLE_THRESHOLD:
                # HACK: the goal index is hard-coded to 1 for this
                # single-goal collection.
                observation['goal_index'] = np.array([1])
                observations.append(observation)
                print(observation)
                if images:
                    img_obs = np.asarray(observation['pixels'])
                    image = img_obs[:np.prod(image_shape)].reshape(image_shape)
                    # Floating pixels are presumably normalized to [-1, 1]
                    # (no 'normalize': False here); map them to uint8
                    # explicitly instead of letting imageio stretch a float
                    # image to its own min/max.
                    if np.issubdtype(image.dtype, np.floating):
                        image = np.clip((image + 1.) * 255. / 2., 0, 255)
                    imageio.imwrite(directory + '/img%i.jpg' % num_positives,
                                    image.astype(np.uint8))
                num_positives += 1
                if num_positives >= NUM_TOTAL_EXAMPLES:
                    break

    # One array per observation key, stacked along a new leading axis.
    goal_examples = {
        key: np.stack([obs[key] for obs in observations], axis=0)
        for key in observations[0]
    }

    with open(directory + '/positives.pkl', 'wb') as file:
        pickle.dump(goal_examples, file)