Ejemplo n.º 1
0
    def _setup_model(self) -> None:
        """Build the rollout buffer and the policy, then prepare clipping schedules.

        The learning-rate schedule and the random seed are set up first.
        ``clip_range`` and, when it is not ``None``, ``clip_range_vf`` are
        replaced by the result of ``get_schedule_fn``.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.rollout_buffer = RolloutBuffer(
            self.n_steps, self.observation_space, self.action_space, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)

        self.policy = self.policy_class(
            self.observation_space, self.action_space, self.lr_schedule,
            use_sde=self.use_sde, device=self.device,
            **self.policy_kwargs)
        # Keep whatever ``.to`` returns as the policy placed on the target device.
        self.policy = self.policy.to(self.device)

        self.clip_range = get_schedule_fn(self.clip_range)

        vf_clip = self.clip_range_vf
        if vf_clip is None:
            # Value-function clipping is deactivated; nothing left to prepare.
            return
        if isinstance(vf_clip, (float, int)):
            # Only plain numbers can be range-checked here.
            assert vf_clip > 0, (
                '`clip_range_vf` must be positive, '
                'pass `None` to deactivate vf clipping')
        self.clip_range_vf = get_schedule_fn(vf_clip)
Ejemplo n.º 2
0
    def _setup_model(self) -> None:
        """Create the graph rollout buffer and the policy, then set up clip schedules.

        ``clip_range`` and, when it is not ``None``, ``clip_range_vf`` are
        replaced by the result of ``get_schedule_fn``.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)
        # TODO: Add graph preprocessor, e.g.
        #   GraphProcessor(use_edges=self.use_edges, use_global=self.use_global)

        # NOTE(review): the second positional argument was annotated as "used to
        # pass preprocessor"; the device is what is passed today -- confirm
        # against GraphRolloutBuffer's signature.
        self.rollout_buffer = GraphRolloutBuffer(
            self.n_steps, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)

        self.policy = self.policy_class(
            self.observation_space, self.action_space, self.lr_schedule,
            use_sde=self.use_sde, device=self.device,
            **self.policy_kwargs)
        self.policy = self.policy.to(self.device)

        # Policy clipping schedule is always created.
        self.clip_range = get_schedule_fn(self.clip_range)

        # Value clipping schedule is created only when a value was supplied.
        vf_clip = self.clip_range_vf
        if vf_clip is None:
            return
        if isinstance(vf_clip, (float, int)):
            assert vf_clip > 0, (
                "`clip_range_vf` must be positive, "
                "pass `None` to deactivate vf clipping")
        self.clip_range_vf = get_schedule_fn(vf_clip)
Ejemplo n.º 3
0
    def _setup_model(self) -> None:
        """Run the parent setup, then turn the clip ranges into schedules.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        super(PPO, self)._setup_model()

        # Policy clipping schedule is always created.
        self.clip_range = get_schedule_fn(self.clip_range)

        # Value clipping schedule is created only when a value was supplied.
        vf_clip = self.clip_range_vf
        if vf_clip is None:
            return
        if isinstance(vf_clip, (float, int)):
            assert vf_clip > 0, (
                "`clip_range_vf` must be positive, "
                "pass `None` to deactivate vf clipping")
        self.clip_range_vf = get_schedule_fn(vf_clip)
Ejemplo n.º 4
0
    def _setup_model(self) -> None:
        """Set up buffer, policy, clip schedules, and the per-robot adaptors.

        After the usual buffer/policy/schedule setup, the robot id of every
        environment in ``self.env.envs`` is collected, handed to the policy,
        and used to build the sensor and motor adaptor module dicts.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        # --- ActorCriticPolicy part ---
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.rollout_buffer = CNSPNSRolloutBuffer(
            self.n_steps, self.observation_space, self.action_space, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
        self.policy = self.policy_class(
            self.observation_space, self.action_space, self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # --- PPO part: schedules for policy/value clipping ---
        self.clip_range = get_schedule_fn(self.clip_range)
        vf_clip = self.clip_range_vf
        if vf_clip is not None:
            if isinstance(vf_clip, (float, int)):
                assert vf_clip > 0, (
                    "`clip_range_vf` must be positive, "
                    "pass `None` to deactivate vf clipping")
            self.clip_range_vf = get_schedule_fn(vf_clip)

        # --- CNSPNSPPO part: plug in adaptors for each body ---
        # One id per environment (order preserved), plus the distinct ids.
        training_robot_ids = [
            self.env.envs[env_idx].robot.robot_id
            for env_idx in range(self.env.num_envs)
        ]
        robot_ids = set(training_robot_ids)
        self.policy.training_robot_ids = training_robot_ids

        sensor_dim = self.observation_space.shape[0]
        self.policy.pns_sensor_adaptor.build_module_dict(robot_ids, sensor_dim)
        motor_dim = self.action_space.shape[0]
        self.policy.pns_motor_adaptor.build_module_dict(robot_ids, motor_dim)
        self.policy.divide_and_use_different_learning_rates()

        # NOTE(review): second device transfer, presumably so that anything
        # created by build_module_dict also ends up on the device -- confirm.
        self.policy = self.policy.to(self.device)
Ejemplo n.º 5
0
    def _setup_model(self) -> None:
        """Run the parent setup, swap in a ``CustomBuffer``, and build clip schedules.

        The rollout buffer assigned here overwrites any buffer the parent
        setup may have stored on ``self.rollout_buffer``.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        super(PPORepresentation, self)._setup_model()

        self.rollout_buffer = CustomBuffer(
            self.n_steps, self.observation_space, self.action_space, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)

        # Policy clipping schedule is always created.
        self.clip_range = get_schedule_fn(self.clip_range)

        # Value clipping schedule is created only when a value was supplied.
        vf_clip = self.clip_range_vf
        if vf_clip is None:
            return
        if isinstance(vf_clip, (float, int)):
            assert vf_clip > 0, (
                "`clip_range_vf` must be positive, "
                "pass `None` to deactivate vf clipping")
        self.clip_range_vf = get_schedule_fn(vf_clip)
Ejemplo n.º 6
0
    def _setup_model(self) -> None:
        """Set up buffer, policy, clip schedules, and the robot-id bookkeeping.

        After the usual setup, every environment's robot id is appended to
        ``self.policy.all_robot_ids``. If this object carries
        ``pns_senser_robot_id_2_idx`` / ``pns_motor_robot_id_2_idx`` mappings,
        their entries are copied (with keys cast to ``int``) into the matching
        ``robot_id_2_idx`` tables of the policy.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        # --- ActorCriticPolicy part ---
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.rollout_buffer = PNSRolloutBuffer(
            self.n_steps, self.observation_space, self.action_space, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
        self.policy = self.policy_class(
            self.observation_space, self.action_space, self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # --- PPO part: schedules for policy/value clipping ---
        self.clip_range = get_schedule_fn(self.clip_range)
        vf_clip = self.clip_range_vf
        if vf_clip is not None:
            if isinstance(vf_clip, (float, int)):
                assert vf_clip > 0, (
                    "`clip_range_vf` must be positive, "
                    "pass `None` to deactivate vf clipping")
            self.clip_range_vf = get_schedule_fn(vf_clip)

        # --- PNSPPO part ---
        # Record the robot id of each environment, in environment order.
        for env_idx in range(self.env.num_envs):
            robot_id = self.env.envs[env_idx].robot.robot_id
            self.policy.all_robot_ids.append(robot_id)

        # Both mappings are optional attributes; keys are normalised to int.
        if hasattr(self, "pns_senser_robot_id_2_idx"):
            for raw_id, slot in self.pns_senser_robot_id_2_idx.items():
                self.policy.features_extractor.robot_id_2_idx[int(raw_id)] = slot
        if hasattr(self, "pns_motor_robot_id_2_idx"):
            for raw_id, slot in self.pns_motor_robot_id_2_idx.items():
                self.policy.pns_motor_net.robot_id_2_idx[int(raw_id)] = slot
Ejemplo n.º 7
0
    def setup_PPO_model(self) -> None:
        """Create the rollout buffer, move the existing policy, and build clip schedules.

        Unlike a full model setup, no policy is constructed here:
        ``self.policy`` must already exist and is only passed through ``.to``
        with the configured device.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.rollout_buffer = RolloutBuffer(
            self.n_steps, self.observation_space, self.action_space, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
        self.policy = self.policy.to(self.device)

        # Policy clipping schedule is always created.
        self.clip_range = get_schedule_fn(self.clip_range)

        # Value clipping schedule is created only when a value was supplied.
        vf_clip = self.clip_range_vf
        if vf_clip is None:
            return
        if isinstance(vf_clip, (float, int)):
            assert vf_clip > 0, (
                "`clip_range_vf` must be positive, "
                "pass `None` to deactivate vf clipping")
        self.clip_range_vf = get_schedule_fn(vf_clip)
Ejemplo n.º 8
0
    def _setup_model(self) -> None:
        """Register the auxiliary LR schedule, run the parent setup, allocate buffers.

        The auxiliary schedule is stored on the instance and injected into
        ``policy_kwargs`` *before* the parent setup runs. Afterwards two
        uninitialised arrays are allocated with
        ``n_steps * n_envs * n_policy_iters`` rows, taking their array
        properties from the rollout buffer's ``observations`` and ``returns``.
        """
        aux_schedule = get_schedule_fn(self.aux_learning_rate)
        self.aux_lr_schedule = aux_schedule
        self.policy_kwargs["aux_lr_schedule"] = aux_schedule

        super(PPG, self)._setup_model()

        rollouts = self.rollout_buffer
        total_rows = self.n_steps * self.n_envs * self.n_policy_iters

        # ``empty_like`` leaves the contents uninitialised; only shape differs
        # from the prototype arrays.
        observations_shape = (total_rows, ) + rollouts.obs_shape
        self._observations_buffer = np.empty_like(
            rollouts.observations, shape=observations_shape)
        self._returns_buffer = np.empty_like(
            rollouts.returns, shape=(total_rows, 1))
Ejemplo n.º 9
0
 def _setup_lr_schedule(self) -> None:
     """Store the learning rate as a schedule, converting it to a callable if needed."""
     schedule = get_schedule_fn(self.learning_rate)
     self.lr_schedule = schedule
Ejemplo n.º 10
0
    def _setup_model(self) -> None:
        """Set up buffer, policy, clip schedules, and the PNS-specific state.

        After the usual setup this method:

        * optionally (``common.args.pns_fix_cns``) turns off ``requires_grad``
          on every policy parameter and turns it back on only for the
          ``weight`` of each ``features_extractor.pns[i]``;
        * appends, per environment index, the ``permutation_matrix`` of the
          sensor and motor PNS weights to the ``last_*_permutation_weights``
          lists;
        * appends every environment's robot id to ``policy.all_robot_ids``;
        * copies the optional ``pns_senser_robot_id_2_idx`` /
          ``pns_motor_robot_id_2_idx`` mappings into the policy, with keys
          cast to ``int``.

        Raises:
            AssertionError: if ``clip_range_vf`` is a float/int that is not
                strictly positive.
        """
        # --- ActorCriticPolicy part ---
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.rollout_buffer = PNSRolloutBuffer(
            self.n_steps, self.observation_space, self.action_space, self.device,
            gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
        self.policy = self.policy_class(
            self.observation_space, self.action_space, self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # --- PPO part: schedules for policy/value clipping ---
        self.clip_range = get_schedule_fn(self.clip_range)
        vf_clip = self.clip_range_vf
        if vf_clip is not None:
            if isinstance(vf_clip, (float, int)):
                assert vf_clip > 0, (
                    "`clip_range_vf` must be positive, "
                    "pass `None` to deactivate vf clipping")
            self.clip_range_vf = get_schedule_fn(vf_clip)

        # --- PNSPPO part ---
        if common.args.pns_fix_cns:
            print("Fix all parameters in CNS.")
            # Freeze everything first ...
            for param in self.policy.parameters():
                param.requires_grad = False
            # ... then re-enable only the sensor-side PNS weights.
            for env_idx in range(self.n_envs):
                sensor_layer = self.policy.features_extractor.pns[env_idx]
                sensor_layer.weight.requires_grad = True

        def _permutation_of(layer):
            # Convert the layer weight to a NumPy array before handing it to
            # permutation_matrix.
            return permutation_matrix(layer.weight.detach().cpu().numpy())

        for env_idx in range(self.n_envs):
            sensor_layer = self.policy.features_extractor.pns[env_idx]
            self.last_sensor_permutation_weights.append(
                _permutation_of(sensor_layer))
            motor_layer = self.policy.pns_motor_net.pns[env_idx]
            self.last_motor_permutation_weights.append(
                _permutation_of(motor_layer))

        # Record the robot id of each environment, in environment order.
        for env_idx in range(self.env.num_envs):
            robot_id = self.env.envs[env_idx].robot.robot_id
            self.policy.all_robot_ids.append(robot_id)

        # Both mappings are optional attributes; keys are normalised to int.
        if hasattr(self, "pns_senser_robot_id_2_idx"):
            for raw_id, slot in self.pns_senser_robot_id_2_idx.items():
                self.policy.features_extractor.robot_id_2_idx[int(raw_id)] = slot
        if hasattr(self, "pns_motor_robot_id_2_idx"):
            for raw_id, slot in self.pns_motor_robot_id_2_idx.items():
                self.policy.pns_motor_net.robot_id_2_idx[int(raw_id)] = slot