예제 #1
0
    def load(self, t=None):
        """Restore trainer state from a checkpoint.

        Args:
            t: Checkpoint selector, forwarded unchanged to
                ``self.ckptr.load``.

        Returns:
            The restored step counter ``self.t``, or ``0`` when the
            checkpointer returns no state dict.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            # Nothing to restore: start counting from step 0.
            self.t = 0
            return self.t

        # Module parameters.
        self.pi.load_state_dict(state_dict['pi'])
        self.qf1.load_state_dict(state_dict['qf1'])
        self.qf2.load_state_dict(state_dict['qf2'])
        self.target_qf1.load_state_dict(state_dict['target_qf1'])
        self.target_qf2.load_state_dict(state_dict['target_qf2'])

        # Optimizer states.
        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_qf1.load_state_dict(state_dict['opt_qf1'])
        self.opt_qf2.load_state_dict(state_dict['opt_qf2'])

        # Compare against None instead of testing truthiness. The truth value
        # of a one-element tensor is its value, so a saved log_alpha of
        # exactly 0.0 would otherwise be skipped together with opt_alpha.
        # A multi-element tensor would raise.
        if state_dict['log_alpha'] is not None:
            with torch.no_grad():
                # In-place copy keeps the existing tensor object (presumably
                # because opt_alpha holds a reference to it -- confirm).
                self.log_alpha.copy_(state_dict['log_alpha'])
            self.opt_alpha.load_state_dict(state_dict['opt_alpha'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        # The replay buffer is stored separately as buffer.npz in the
        # checkpoint directory. Its nesting structure is kept in the state
        # dict under 'buffer_format'.
        buffer_format = state_dict['buffer_format']
        buffer_path = os.path.join(self.ckptr.ckptdir, 'buffer.npz')
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code. Only load checkpoint directories
        # you trust.
        # The NpzFile is used as a context manager so its file handle is
        # closed. dict() reads every array eagerly, so the data stays valid
        # after the handle closes.
        with np.load(buffer_path, allow_pickle=True) as buffer_npz:
            buffer_state = dict(buffer_npz)
        # Flatten the loaded arrays and re-nest them into the recorded format.
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.data_manager.manual_reset()
        return self.t
예제 #2
0
 def load(self, t=None):
     """Restore ``pi``, ``opt``, env state and step counter from a checkpoint.

     Args:
         t: Checkpoint selector, forwarded unchanged to ``self.ckptr.load``.

     Returns:
         ``self.t`` after loading: the saved step, or ``0`` when the
         checkpointer returned no state.
     """
     ckpt = self.ckptr.load(t)
     if ckpt is not None:
         self.pi.load_state_dict(ckpt['pi'])
         self.opt.load_state_dict(ckpt['opt'])
         misc.env_load_state_dict(self.env, ckpt['env'])
         self.t = ckpt['t']
     else:
         # No checkpoint available: start from step 0.
         self.t = 0
     return self.t
예제 #3
0
 def load(self, t=None):
     """Restore trainer state from checkpoint ``t``.

     Args:
         t: Checkpoint selector, forwarded unchanged to ``self.ckptr.load``.

     Returns:
         The restored step counter ``self.t``; ``0`` when the checkpointer
         has no state dict to return.
     """
     ckpt = self.ckptr.load(t)
     if ckpt is None:
         # No checkpoint: start counting from zero.
         self.t = 0
     else:
         # Objects restored through their own load_state_dict, in this order.
         for key in ('pi', 'opt', 'opt_l'):
             getattr(self, key).load_state_dict(ckpt[key])
         # Overwrite the tensor's storage in place. The saved entry is keyed
         # 'lambda_' although the attribute is log_lambda_ (presumably it
         # holds the log value -- confirm against the saving code).
         self.log_lambda_.data.copy_(ckpt['lambda_'])
         misc.env_load_state_dict(self.env, ckpt['env'])
         self._actor.load_state_dict(ckpt['_actor'])
         self.t = ckpt['t']
     return self.t
 def load(self, t=None):
     """Restore modules, optimizers, KL weight, env and step from a checkpoint.

     Args:
         t: Checkpoint selector, forwarded unchanged to ``self.ckptr.load``.

     Returns:
         ``self.t`` after loading: the saved step, or ``0`` when the
         checkpointer returned no state.
     """
     ckpt = self.ckptr.load(t)
     if ckpt is None:
         # Nothing saved yet: begin at step 0.
         self.t = 0
     else:
         # Each of these exposes load_state_dict; keep the original order.
         for key in ('pi', 'vf', 'opt_pi', 'opt_vf'):
             getattr(self, key).load_state_dict(ckpt[key])
         # kl_weight is stored as a plain value, so it is reassigned directly.
         self.kl_weight = ckpt['kl_weight']
         misc.env_load_state_dict(self.env, ckpt['env'])
         self._actor.load_state_dict(ckpt['_actor'])
         self.t = ckpt['t']
     return self.t
예제 #5
0
    def load(self, t=None):
        """Restore trainer state, including the replay buffer, from a checkpoint.

        Args:
            t: Checkpoint selector, forwarded unchanged to
                ``self.ckptr.load``.

        Returns:
            The restored step counter ``self.t``, or ``0`` when the
            checkpointer returns no state dict.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            # Nothing to restore: start counting from step 0.
            self.t = 0
            return self.t

        self.qf.load_state_dict(state_dict['qf'])
        self.qf_targ.load_state_dict(state_dict['qf_targ'])
        self.opt.load_state_dict(state_dict['opt'])
        self._actor.load_state_dict(state_dict['_actor'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        # The replay buffer is stored separately as buffer.npz in the
        # checkpoint directory. Its nesting structure is kept in the state
        # dict under 'buffer_format'.
        buffer_format = state_dict['buffer_format']
        buffer_path = os.path.join(self.ckptr.ckptdir, 'buffer.npz')
        # The NpzFile is used as a context manager so its file handle is
        # closed. dict() reads every array eagerly, so the data stays valid
        # after the handle closes.
        with np.load(buffer_path) as buffer_npz:
            buffer_state = dict(buffer_npz)
        # Flatten the loaded arrays and re-nest them into the recorded format.
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.data_manager.manual_reset()
        return self.t