示例#1
0
    def _initialize(self):
        """Initialize non-common things."""
        # load demo replay memory
        with open(self.args.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params.use_her:
            self.her = build_her(self.hyper_params.her)
            print(f"[INFO] Build {str(self.her)}.")

            if self.hyper_params.desired_states_from_demo:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1, ))
            demo = self.her.generate_demo_transitions(demo)

            if not self.her.is_goal_in_state:
                self.state_dim = (self.state_dim[0] * 2, )
        else:
            self.her = None

        if not self.args.test:
            # Replay buffers
            demo_batch_size = self.hyper_params.demo_batch_size
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)

            self.memory = ReplayBuffer(self.hyper_params.buffer_size,
                                       demo_batch_size)

            # set hyper parameters
            self.lambda2 = 1.0 / demo_batch_size
示例#2
0
    def _initialize(self):
        """Set up demo data, optional HER, buffers, the BC-SAC learner and logs.

        Loads pickled demonstration transitions from ``self.args.demo_path``
        and, if ``hyper_params.use_her`` is set, regenerates them through a
        HER helper (doubling ``self.state_dim[0]`` when
        ``her.is_goal_in_state`` is false). Outside test mode it builds the
        demo and online replay buffers and stores ``lambda2`` in
        ``hyper_params``. It then points ``args.cfg_path``/``args.load_from``
        at their ``offer`` counterparts and builds a ``"BCSACLearner"``. While
        the learner is built, it temporarily exposes ``sac_buffer_size`` /
        ``sac_batch_size`` as ``buffer_size`` / ``batch_size``. Finally it
        creates the frame-stack deques and empty per-run record lists.
        """
        args = self.args
        params = self.hyper_params

        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted demo files.
        with open(args.demo_path, "rb") as demo_file:
            demo = list(pickle.load(demo_file))

        if not params.use_her:
            self.her = None
        else:
            her = build_her(params.her)
            self.her = her
            print(f"[INFO] Build {str(her)}.")

            # Desired states come from the raw demo, before HER regeneration.
            if params.desired_states_from_demo:
                her.fetch_desired_states_from_demo(demo)

            self.transitions_epi = []
            self.desired_state = np.zeros((1, ))
            demo = her.generate_demo_transitions(demo)

            if not her.is_goal_in_state:
                self.state_dim = (self.state_dim[0] * 2, )

        if not args.test:
            demo_bs = params.demo_batch_size
            self.demo_memory = ReplayBuffer(len(demo), demo_bs)
            self.demo_memory.extend(demo)
            self.memory = ReplayBuffer(params.sac_buffer_size, demo_bs)
            params["lambda2"] = 1.0 / demo_bs

        args.cfg_path = args.offer_cfg_path
        args.load_from = args.load_offer_from

        # (exposed name, source name) pairs that exist on params only while
        # the learner is being built.
        temporary_keys = (
            ("buffer_size", "sac_buffer_size"),
            ("batch_size", "sac_batch_size"),
        )
        for exposed, source in temporary_keys:
            setattr(params, exposed, getattr(params, source))

        self.learner_cfg.type = "BCSACLearner"
        # NOTE(review): the learner config receives this very object, so the
        # deletions below also affect it if the learner keeps a reference —
        # confirm that is intended.
        self.learner_cfg.hyper_params = params
        self.learner = build_learner(self.learner_cfg)

        for exposed, _ in temporary_keys:
            delattr(params, exposed)

        # Frame-stack buffers, both bounded to the configured stack size.
        stack_size = args.stack_size
        self.stack_size = stack_size
        self.stack_buffer = deque(maxlen=stack_size)
        self.stack_buffer_2 = deque(maxlen=stack_size)

        self.scores, self.utilities = [], []
        self.rounds, self.opp_utilities = [], []
示例#3
0
    def _initialize(self):
        """Initialize non-common things: demo data, optional HER, buffers, learner.

        Loads pickled demonstration transitions from
        ``self.hyper_params.demo_path``. If ``hyper_params.use_her`` is set,
        builds the HER helper, optionally takes desired states from the demo,
        regenerates the demo transitions through HER and, when
        ``her.is_goal_in_state`` is false, doubles the first dimension of
        ``self.env_info.observation_space.shape``. Otherwise ``self.her`` is
        ``None``. Outside test mode it creates the demo and online replay
        buffers and stores ``lambda2 = 1 / demo_batch_size`` in
        ``hyper_params``. Finally it builds ``self.learner`` via
        ``build_learner``, passing the (possibly enlarged) observation shape
        as ``state_size``.
        """
        # load demo replay memory
        # NOTE(review): pickle.load can execute arbitrary code; demo_path must
        # point at trusted data.
        with open(self.hyper_params.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params.use_her:
            self.her = build_her(self.hyper_params.her)
            print(f"[INFO] Build {str(self.her)}.")

            # Desired states come from the raw demo, before HER regeneration.
            if self.hyper_params.desired_states_from_demo:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1, ))
            demo = self.her.generate_demo_transitions(demo)

            if not self.her.is_goal_in_state:
                # Double the first observation dimension; the learner below is
                # built with this shape as its state_size. (Previously this
                # read ``self.self.env_info`` and raised AttributeError.)
                # NOTE(review): requires the space object to allow assigning
                # ``shape`` — confirm for the gym version in use.
                observation_space = self.env_info.observation_space
                observation_space.shape = (observation_space.shape[0] * 2, )
        else:
            self.her = None

        if not self.is_test:
            # Replay buffers
            demo_batch_size = self.hyper_params.demo_batch_size
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)

            self.memory = ReplayBuffer(self.hyper_params.buffer_size,
                                       self.hyper_params.batch_size)

            # set hyper parameters
            self.hyper_params["lambda2"] = 1.0 / demo_batch_size

        build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            noise_cfg=self.noise_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.shape[0],
            is_test=self.is_test,
            load_from=self.load_from,
        )
        self.learner = build_learner(self.learner_cfg, build_args)