Example #1
class SacAgent(BaseAgent):
    """TO BE DEPRECATED."""

    def __init__(
            self,
            ModelCls=PiMlpModel,  # Pi model.
            QModelCls=QofMuMlpModel,
            VModelCls=VMlpModel,
            model_kwargs=None,  # Pi model.
            q_model_kwargs=None,
            v_model_kwargs=None,
            initial_model_state_dict=None,  # All models.
            action_squash=1.,  # Max magnitude (or None).
            pretrain_std=0.75,  # With squash 0.75 is near uniform.
            ):
        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[256, 256])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[256, 256])
        if v_model_kwargs is None:
            v_model_kwargs = dict(hidden_sizes=[256, 256])
        super().__init__(ModelCls=ModelCls, model_kwargs=model_kwargs,
            initial_model_state_dict=initial_model_state_dict)
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self, env_spaces, share_memory=False,
            global_B=1, env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  # Don't let base agent try to load.
        super().initialize(env_spaces, share_memory,
            global_B=global_B, env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q1_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        self.v_model = self.VModelCls(**self.env_model_kwargs, **self.v_model_kwargs)
        self.target_v_model = self.VModelCls(**self.env_model_kwargs,
            **self.v_model_kwargs)
        self.target_v_model.load_state_dict(self.v_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        self.q1_model.to(self.device)
        self.q2_model.to(self.device)
        self.v_model.to(self.device)
        self.target_v_model.to(self.device)

    def data_parallel(self):
        super().data_parallel()
        DDP_WRAP = DDPC if self.device.type == "cpu" else DDP
        self.q1_model = DDP_WRAP(self.q1_model)
        self.q2_model = DDP_WRAP(self.q2_model)
        self.v_model = DDP_WRAP(self.v_model)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to((observation, prev_action, prev_reward,
            action), device=self.device)
        q1 = self.q1_model(*model_inputs)
        q2 = self.q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def v(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        v = self.v_model(*model_inputs)
        return v.cpu()

    def pi(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        mean, log_std = self.model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        # action = self.distribution.sample(dist_info)
        # log_pi = self.distribution.log_likelihood(action, dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    def target_v(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        target_v = self.target_v_model(*model_inputs)
        return target_v.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        mean, log_std = self.model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_v_model, self.v_model.state_dict(), tau)

    @property
    def models(self):
        return Models(pi=self.model, q1=self.q1_model, q2=self.q2_model,
            v=self.v_model)

    def pi_parameters(self):
        return self.model.parameters()

    def q1_parameters(self):
        return self.q1_model.parameters()

    def q2_parameters(self):
        return self.q2_model.parameters()

    def v_parameters(self):
        return self.v_model.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q1_model.train()
        self.q2_model.train()
        self.v_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        self.v_model.eval()
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        self.v_model.eval()
        self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),  # Pi model.
            q1_model=self.q1_model.state_dict(),
            q2_model=self.q2_model.state_dict(),
            v_model=self.v_model.state_dict(),
            target_v_model=self.target_v_model.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.q1_model.load_state_dict(state_dict["q1_model"])
        self.q2_model.load_state_dict(state_dict["q2_model"])
        self.v_model.load_state_dict(state_dict["v_model"])
        self.target_v_model.load_state_dict(state_dict["target_v_model"])
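Both the SAC and DDPG variants in these examples keep their target networks in sync through `update_target(tau)`, which delegates to `update_state_dict(target_model, source_state_dict, tau)`. As a rough sketch of the intended semantics (not rlpyt's actual helper; the function name and usage comment below are assumptions), a Polyak-style soft update interpolates each parameter with factor `tau`, where `tau=1` is a hard copy:

def soft_update_sketch(target_model, source_state_dict, tau=1.0):
    """Blend source params into target_model: new = tau * source + (1 - tau) * old."""
    target_state = target_model.state_dict()
    mixed = {k: tau * source_state_dict[k] + (1 - tau) * v
             for k, v in target_state.items()}
    target_model.load_state_dict(mixed)

# Hypothetical usage mirroring update_target(tau=0.005) on the agent above:
# soft_update_sketch(agent.target_v_model, agent.v_model.state_dict(), tau=0.005)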
Example #2
class DdpgAgent(BaseAgent):
    """Agent for deep deterministic policy gradient algorithm."""

    shared_mu_model = None

    def __init__(
        self,
        ModelCls=MuMlpModel,  # Mu model.
        QModelCls=QofMuMlpModel,
        model_kwargs=None,  # Mu model.
        q_model_kwargs=None,
        initial_model_state_dict=None,  # Mu model.
        initial_q_model_state_dict=None,
        action_std=0.1,
        action_noise_clip=None,
    ):
        """Saves input arguments; default network sizes saved here."""
        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[400, 300])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[400, 300])
        save__init__args(locals())
        super().__init__()  # For async setup.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        """Instantiates mu and q, and target_mu and target_q models."""
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.q_model = self.QModelCls(**self.env_model_kwargs,
                                      **self.q_model_kwargs)
        if self.initial_q_model_state_dict is not None:
            self.q_model.load_state_dict(self.initial_q_model_state_dict)
        self.target_model = self.ModelCls(**self.env_model_kwargs,
                                          **self.model_kwargs)
        self.target_q_model = self.QModelCls(**self.env_model_kwargs,
                                             **self.q_model_kwargs)
        self.target_q_model.load_state_dict(self.q_model.state_dict())
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            std=self.action_std,
            noise_clip=self.action_noise_clip,
            clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)  # Takes care of self.model.
        self.target_model.to(self.device)
        self.q_model.to(self.device)
        self.target_q_model.to(self.device)

    def data_parallel(self):
        device_id = super().data_parallel()  # Takes care of self.model.
        self.q_model = DDP(
            self.q_model,
            device_ids=None if device_id is None else [device_id],  # 1 GPU.
            output_device=device_id,
        )
        return device_id

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        """Compute Q-value for input state/observation and action (with grad)."""
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q = self.q_model(*model_inputs)
        return q.cpu()

    def q_at_mu(self, observation, prev_action, prev_reward):
        """Compute Q-value for input state/observation, through the mu_model
        (with grad)."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu = self.model(*model_inputs)
        q = self.q_model(*model_inputs, mu)
        return q.cpu()

    def target_q_at_mu(self, observation, prev_action, prev_reward):
        """Compute target Q-value for input state/observation, through the
        target mu_model."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        target_mu = self.target_model(*model_inputs)
        target_q_at_mu = self.target_q_model(*model_inputs, target_mu)
        return target_q_at_mu.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        """Computes distribution parameters (mu) for state/observation,
        returns (gaussian) sampled action."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu = self.model(*model_inputs)
        action = self.distribution.sample(DistInfo(mean=mu))
        agent_info = AgentInfo(mu=mu)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_model, self.model.state_dict(), tau)
        update_state_dict(self.target_q_model, self.q_model.state_dict(), tau)

    def q_parameters(self):
        return self.q_model.parameters()

    def mu_parameters(self):
        return self.model.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q_model.eval()
        self.distribution.set_std(self.action_std)

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q_model.eval()
        self.distribution.set_std(0.)  # Deterministic.

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),
            q_model=self.q_model.state_dict(),
            target_model=self.target_model.state_dict(),
            target_q_model=self.target_q_model.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.q_model.load_state_dict(state_dict["q_model"])
        self.target_model.load_state_dict(state_dict["target_model"])
        self.target_q_model.load_state_dict(state_dict["target_q_model"])
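The `Gaussian` distribution configured in `initialize()` is what turns the deterministic mu output into an exploratory action inside `step()`. A minimal sketch of that behavior (illustrative only; the real rlpyt `Gaussian` class covers more cases, and the parameter names below are hypothetical):

import torch

def ddpg_explore_sketch(mu, std=0.1, noise_clip=None, action_high=1.0):
    """Add (optionally clipped) Gaussian noise to mu, then clip to the action bound."""
    noise = std * torch.randn_like(mu)
    if noise_clip is not None:
        noise = noise.clamp(-noise_clip, noise_clip)
    return (mu + noise).clamp(-action_high, action_high)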
Example #3
File: sac_agent.py  Project: keirp/glamor
class SacAgent(BaseAgent):
    """Agent for SAC algorithm, including action-squashing, using twin Q-values."""

    def __init__(
            self,
            ModelCls=PiMlpModel,  # Pi model.
            QModelCls=QofMuMlpModel,
            model_kwargs=None,  # Pi model.
            q_model_kwargs=None,
            v_model_kwargs=None,
            initial_model_state_dict=None,  # All models.
            action_squash=1.,  # Max magnitude (or None).
            pretrain_std=0.75,  # With squash 0.75 is near uniform.
    ):
        """Saves input arguments; network defaults stored within."""
        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[256, 256])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[256, 256])
        if v_model_kwargs is None:
            v_model_kwargs = dict(hidden_sizes=[256, 256])
        super().__init__(ModelCls=ModelCls, model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict)
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self, env_spaces, share_memory=False,
                   global_B=1, env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        # Don't let base agent try to load.
        self.initial_model_state_dict = None
        super().initialize(env_spaces, share_memory,
                           global_B=global_B, env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q1_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        self.q1_model.to(self.device)
        self.q2_model.to(self.device)
        self.target_q1_model.to(self.device)
        self.target_q2_model.to(self.device)

    def data_parallel(self):
        device_id = super().data_parallel()
        self.q1_model = DDP(
            self.q1_model,
            device_ids=None if device_id is None else [device_id],  # 1 GPU.
            output_device=device_id,
        )
        self.q2_model = DDP(
            self.q2_model,
            device_ids=None if device_id is None else [device_id],  # 1 GPU.
            output_device=device_id,
        )
        return device_id

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        """Compute twin Q-values for state/observation and input action 
        (with grad)."""
        model_inputs = buffer_to((observation, prev_action, prev_reward,
                                  action), device=self.device)
        q1 = self.q1_model(*model_inputs)
        q2 = self.q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def target_q(self, observation, prev_action, prev_reward, action):
        """Compute twin target Q-values for state/observation and input
        action."""
        model_inputs = buffer_to((observation, prev_action,
                                  prev_reward, action), device=self.device)
        target_q1 = self.target_q1_model(*model_inputs)
        target_q2 = self.target_q2_model(*model_inputs)
        return target_q1.cpu(), target_q2.cpu()

    def pi(self, observation, prev_action, prev_reward):
        """Compute action log-probabilities for state/observation, and
        sample new action (with grad).  Uses special ``sample_loglikelihood()``
        method of Gaussian distribution, which handles action squashing
        through this process."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mean, log_std = self.model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        # action = self.distribution.sample(dist_info)
        # log_pi = self.distribution.log_likelihood(action, dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        # Action stays on device for q models.
        return action, log_pi, dist_info

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mean, log_std = self.model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_q1_model,
                          self.q1_model.state_dict(), tau)
        update_state_dict(self.target_q2_model,
                          self.q2_model.state_dict(), tau)

    @property
    def models(self):
        return Models(pi=self.model, q1=self.q1_model, q2=self.q2_model)

    def pi_parameters(self):
        return self.model.parameters()

    def q1_parameters(self):
        return self.q1_model.parameters()

    def q2_parameters(self):
        return self.q2_model.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q1_model.train()
        self.q2_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),  # Pi model.
            q1_model=self.q1_model.state_dict(),
            q2_model=self.q2_model.state_dict(),
            target_q1_model=self.target_q1_model.state_dict(),
            target_q2_model=self.target_q2_model.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.q1_model.load_state_dict(state_dict["q1_model"])
        self.q2_model.load_state_dict(state_dict["q2_model"])
        self.target_q1_model.load_state_dict(state_dict["target_q1_model"])
        self.target_q2_model.load_state_dict(state_dict["target_q2_model"])
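The twin target Q-values and the reparameterized `pi()` output are combined by the SAC algorithm (outside this agent) into an entropy-regularized Bellman target. A hedged sketch of that combination, assuming the algorithm supplies `alpha`, `discount`, and the `reward`/`done` tensors:

import torch

def sac_target_sketch(agent, next_obs, prev_action, prev_reward,
                      reward, done, alpha=0.2, discount=0.99):
    """Clipped double-Q target with entropy bonus (sketch, not rlpyt's algorithm)."""
    with torch.no_grad():
        next_action, next_log_pi, _ = agent.pi(next_obs, prev_action, prev_reward)
        q1, q2 = agent.target_q(next_obs, prev_action, prev_reward, next_action)
        target_v = torch.min(q1, q2) - alpha * next_log_pi
    return reward + (1.0 - done.float()) * discount * target_v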
Example #4
class SacAgent(BaseAgent):
    def __init__(
        self,
        ModelCls=SacModel,
        ConvModelCls=SacConvModel,
        Fc1ModelCls=SacFc1Model,
        PiModelCls=SacActorModel,
        QModelCls=SacCriticModel,
        conv_kwargs=None,
        fc1_kwargs=None,
        pi_model_kwargs=None,
        q_model_kwargs=None,
        initial_state_dict=None,
        action_squash=1.0,
        pretrain_std=0.75,  # 0.75 gets pretty uniform squashed actions
        load_conv=False,
        load_all=False,
        state_dict_filename=None,
        store_latent=False,
    ):
        if conv_kwargs is None:
            conv_kwargs = dict()
        if fc1_kwargs is None:
            fc1_kwargs = dict(latent_size=50)  # default
        if pi_model_kwargs is None:
            pi_model_kwargs = dict(hidden_sizes=[1024, 1024])  # default
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[1024, 1024])  # default
        save__init__args(locals())
        super().__init__(ModelCls=SacModel)
        self.min_itr_learn = 0  # Get from algo.
        assert not (load_conv and load_all)

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        self.conv = self.ConvModelCls(image_shape=env_spaces.observation.shape,
                                      **self.conv_kwargs)
        self.q_fc1 = self.Fc1ModelCls(input_size=self.conv.output_size,
                                      **self.fc1_kwargs)
        self.pi_fc1 = self.Fc1ModelCls(input_size=self.conv.output_size,
                                       **self.fc1_kwargs)

        latent_size = self.q_fc1.output_size
        action_size = env_spaces.action.shape[0]

        # These are just MLPs
        self.pi_mlp = self.PiModelCls(input_size=latent_size,
                                      action_size=action_size,
                                      **self.pi_model_kwargs)
        self.q_mlps = self.QModelCls(input_size=latent_size,
                                     action_size=action_size,
                                     **self.q_model_kwargs)
        self.target_q_mlps = copy.deepcopy(self.q_mlps)  # Separate params.

        # Make reference to the full actor model including encoder.
        # CAREFUL ABOUT TRAIN MODE FOR LAYER NORM IF CHANGING THIS?
        self.model = SacModel(conv=self.conv,
                              pi_fc1=self.pi_fc1,
                              pi_mlp=self.pi_mlp)

        if self.load_conv:
            logger.log("Agent loading state dict: " + self.state_dict_filename)
            loaded_state_dict = torch.load(self.state_dict_filename,
                                           map_location=torch.device("cpu"))
            # From UL, saves snapshot: params["algo_state_dict"]["encoder"]
            if "algo_state_dict" in loaded_state_dict:
                loaded_state_dict = loaded_state_dict
            loaded_state_dict = loaded_state_dict.get("algo_state_dict",
                                                      loaded_state_dict)
            loaded_state_dict = loaded_state_dict.get("encoder",
                                                      loaded_state_dict)
            # A bit onerous, but ensures that state dicts match:
            conv_state_dict = OrderedDict([
                (k, v)  # .replace("conv.", "", 1)
                for k, v in loaded_state_dict.items() if k.startswith("conv.")
            ])
            self.conv.load_state_dict(conv_state_dict)
            # Double check it gets into the q_encoder as well.
            logger.log("Agent loaded CONV state dict.")
        elif self.load_all:
            # From RL, saves snapshot: params["agent_state_dict"]
            loaded_state_dict = torch.load(self.state_dict_filename,
                                           map_location=torch.device("cpu"))
            self.load_state_dict(loaded_state_dict["agent_state_dict"])
            logger.log("Agnet loaded FULL state dict.")
        else:
            logger.log("Agent NOT loading state dict.")

        self.target_conv = copy.deepcopy(self.conv)
        self.target_q_fc1 = copy.deepcopy(self.q_fc1)

        if share_memory:
            # The actor model needs to share memory to sampler workers, and
            # this includes handling the encoder!
            # (Almost always just run serial anyway, no sharing.)
            self.model.share_memory()
            self.shared_model = self.model
        if self.initial_state_dict is not None:
            raise NotImplementedError
        self.env_spaces = env_spaces
        self.share_memory = share_memory

        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            # min_std=np.exp(MIN_LOG_STD),  # NOPE IN PI_MODEL NOW
            # max_std=np.exp(MAX_LOG_STD),
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)  # Takes care of self.model only.
        self.conv.to(self.device)  # should already be done
        self.q_fc1.to(self.device)
        self.pi_fc1.to(self.device)  # should already be done
        self.q_mlps.to(self.device)
        self.pi_mlp.to(self.device)  # should already be done
        self.target_conv.to(self.device)
        self.target_q_fc1.to(self.device)
        self.target_q_mlps.to(self.device)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        observation, prev_action, prev_reward = buffer_to(
            (observation, prev_action, prev_reward), device=self.device)
        # self.model includes encoder + actor MLP.
        mean, log_std, latent, conv = self.model(observation, prev_action,
                                                 prev_reward)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info,
                               conv=conv if self.store_latent else None)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def q(self, conv_out, prev_action, prev_reward, action):
        """Compute twin Q-values for state/observation and input action
        (with grad).
        Assume variables already on device."""
        latent = self.q_fc1(conv_out)
        q1, q2 = self.q_mlps(latent, action, prev_action, prev_reward)
        return q1.cpu(), q2.cpu()

    def target_q(self, conv_out, prev_action, prev_reward, action):
        """Compute twin target Q-values for state/observation and input
        action.
        Assume variables already on device."""
        latent = self.target_q_fc1(conv_out)
        target_q1, target_q2 = self.target_q_mlps(latent, action, prev_action,
                                                  prev_reward)
        return target_q1.cpu(), target_q2.cpu()

    def pi(self, conv_out, prev_action, prev_reward):
        """Compute action log-probabilities for state/observation, and
        sample new action (with grad).  Uses special ``sample_loglikelihood()``
        method of Gaussian distribution, which handles action squashing
        through this process.
        Assume variables already on device."""
        # Call just the actor mlp, not the encoder.
        latent = self.pi_fc1(conv_out)
        mean, log_std = self.pi_mlp(latent, prev_action, prev_reward)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        # action = self.distribution.sample(dist_info)
        # log_pi = self.distribution.log_likelihood(action, dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    def train_mode(self, itr):
        super().train_mode(itr)  # pi_encoder in here in model
        self.conv.train()  # should already be done
        self.q_fc1.train()
        self.pi_fc1.train()  # should already be done
        self.q_mlps.train()
        self.pi_mlp.train()  # should already be done

    def sample_mode(self, itr):
        super().sample_mode(itr)  # pi_encoder in here in model
        self.conv.eval()  # should already be done
        self.q_fc1.eval()
        self.pi_fc1.eval()  # should already be done
        self.q_mlps.eval()  # not used anyway
        self.pi_mlp.eval()  # should already be done
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)  # pi_encoder in here in model
        self.conv.eval()  # should already be done
        self.q_fc1.eval()
        self.pi_fc1.eval()  # should already be done
        self.q_mlps.eval()  # not used anyway
        self.pi_mlp.eval()  # should already be done
        self.distribution.set_std(
            0.0)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return dict(
            conv=self.conv.state_dict(),
            q_fc1=self.q_fc1.state_dict(),
            pi_fc1=self.pi_fc1.state_dict(),
            q_mlps=self.q_mlps.state_dict(),
            pi_mlp=self.pi_mlp.state_dict(),
            target_conv=self.target_conv.state_dict(),
            target_q_fc1=self.target_q_fc1.state_dict(),
            target_q_mlps=self.target_q_mlps.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.conv.load_state_dict(state_dict["conv"])
        self.q_fc1.load_state_dict(state_dict["q_fc1"])
        self.pi_fc1.load_state_dict(state_dict["pi_fc1"])
        self.q_mlps.load_state_dict(state_dict["q_mlps"])
        self.pi_mlp.load_state_dict(state_dict["pi_mlp"])
        self.target_conv.load_state_dict(state_dict["target_conv"])
        self.target_q_fc1.load_state_dict(state_dict["target_q_fc1"])
        self.target_q_mlps.load_state_dict(state_dict["target_q_mlps"])

    def data_parallel(self, *args, **kwargs):
        raise NotImplementedError  # Do it later.

    def async_cpu(self, *args, **kwargs):
        raise NotImplementedError  # Double check this...

    def update_targets(self, q_tau=1, encoder_tau=1):
        """Do each parameter ONLY ONCE."""
        update_state_dict(self.target_conv, self.conv.state_dict(),
                          encoder_tau)
        update_state_dict(self.target_q_fc1, self.q_fc1.state_dict(),
                          encoder_tau)
        update_state_dict(self.target_q_mlps, self.q_mlps.state_dict(), q_tau)
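The `load_conv` branch above filters an arbitrary checkpoint down to just the encoder's entries by key prefix. The same pattern as a small standalone helper (a sketch; `strip_prefix` is a hypothetical option for when the destination module does not itself nest parameters under a `conv.` prefix):

from collections import OrderedDict

def extract_prefixed_state_dict(state_dict, prefix="conv.", strip_prefix=False):
    """Keep only entries whose keys start with `prefix`, optionally stripping it."""
    items = [(k, v) for k, v in state_dict.items() if k.startswith(prefix)]
    if strip_prefix:
        items = [(k[len(prefix):], v) for k, v in items]
    return OrderedDict(items)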
Example #5
class DdpgAgent(BaseAgent):

    shared_mu_model = None

    def __init__(
        self,
        ModelCls=MuMlpModel,  # Mu model.
        QModelCls=QofMuMlpModel,
        model_kwargs=None,  # Mu model.
        q_model_kwargs=None,
        initial_model_state_dict=None,  # Mu model.
        initial_q_model_state_dict=None,
        action_std=0.1,
        action_noise_clip=None,
    ):
        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[400, 300])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[400, 300])
        save__init__args(locals())
        super().__init__()  # For async setup.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.q_model = self.QModelCls(**self.env_model_kwargs,
                                      **self.q_model_kwargs)
        if self.initial_q_model_state_dict is not None:
            self.q_model.load_state_dict(self.initial_q_model_state_dict)
        self.target_model = self.ModelCls(**self.env_model_kwargs,
                                          **self.model_kwargs)
        self.target_q_model = self.QModelCls(**self.env_model_kwargs,
                                             **self.q_model_kwargs)
        self.target_q_model.load_state_dict(self.q_model.state_dict())
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            std=self.action_std,
            noise_clip=self.action_noise_clip,
            clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)  # Takes care of self.model.
        self.target_model.to(self.device)
        self.q_model.to(self.device)
        self.target_q_model.to(self.device)

    def data_parallel(self):
        super().data_parallel()  # Takes care of self.model.
        if self.device.type == "cpu":
            self.q_model = DDPC(self.q_model)
        else:
            self.q_model = DDP(self.q_model)

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q = self.q_model(*model_inputs)
        return q.cpu()

    def q_at_mu(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu = self.model(*model_inputs)
        q = self.q_model(*model_inputs, mu)
        return q.cpu()

    def target_q_at_mu(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        target_mu = self.target_model(*model_inputs)
        target_q_at_mu = self.target_q_model(*model_inputs, target_mu)
        return target_q_at_mu.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu = self.model(*model_inputs)
        action = self.distribution.sample(DistInfo(mean=mu))
        agent_info = AgentInfo(mu=mu)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_model, self.model.state_dict(), tau)
        update_state_dict(self.target_q_model, self.q_model.state_dict(), tau)

    def q_parameters(self):
        return self.q_model.parameters()

    def mu_parameters(self):
        return self.model.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q_model.eval()
        self.distribution.set_std(self.action_std)

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q_model.eval()
        self.distribution.set_std(0.)  # Deterministic.

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),
            q_model=self.q_model.state_dict(),
            target_model=self.target_model.state_dict(),
            target_q_model=self.target_q_model.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.q_model.load_state_dict(state_dict["q_model"])
        self.target_model.load_state_dict(state_dict["target_model"])
        self.target_q_model.load_state_dict(state_dict["target_q_model"])
Example #6
class Td3Agent(DdpgAgent):
    def __init__(
            self,
            pretrain_std=2.,  # To make actions roughly uniform.
            target_noise_std=0.2,
            target_noise_clip=0.5,
            initial_q2_model_state_dict=None,
            **kwargs):
        super().__init__(**kwargs)
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces, share_memory, global_B, env_ranks)
        self.q2_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)
        if self.initial_q2_model_state_dict is not None:
            self.q2_model.load_state_dict(self.initial_q2_model_state_dict)
        self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        self.target_distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            std=self.target_noise_std,
            noise_clip=self.target_noise_clip,
            clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        self.q2_model.to(self.device)
        self.target_q2_model.to(self.device)

    def data_parallel(self):
        super().data_parallel()
        if self.device.type == "cpu":
            self.q2_model = DDPC(self.q2_model)
        else:
            self.q2_model = DDP(self.q2_model)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q1 = self.q_model(*model_inputs)
        q2 = self.q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def target_q_at_mu(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        target_mu = self.target_model(*model_inputs)
        target_action = self.target_distribution.sample(
            DistInfo(mean=target_mu))
        target_q1_at_mu = self.target_q_model(*model_inputs, target_action)
        target_q2_at_mu = self.target_q2_model(*model_inputs, target_action)
        return target_q1_at_mu.cpu(), target_q2_at_mu.cpu()

    def update_target(self, tau=1):
        super().update_target(tau)
        update_state_dict(self.target_q2_model, self.q2_model.state_dict(),
                          tau)

    def q_parameters(self):
        yield from self.q_model.parameters()
        yield from self.q2_model.parameters()

    def set_target_noise(self, std, noise_clip=None):
        self.target_distribution.set_std(std)
        self.target_distribution.set_noise_clip(noise_clip)

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q2_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q2_model.eval()
        std = self.action_std if itr >= self.min_itr_learn else self.pretrain_std
        if itr == 0 or itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: {std}.")
        self.distribution.set_std(std)

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q2_model.eval()

    def state_dict(self):
        state_dict = super().state_dict()
        state_dict["q2_model"] = self.q2_model.state_dict()
        state_dict["target_q2_model"] = self.target_q2_model.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.q2_model.load_state_dict(state_dict["q2_model"])
        self.target_q2_model.load_state_dict(state_dict["target_q2_model"])
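`Td3Agent.target_q_at_mu()` implements target-policy smoothing: clipped noise from `target_distribution` perturbs the target action before both target critics evaluate it, and the algorithm then takes the minimum of the two. A sketch of how the pieces fit together, assuming the algorithm supplies the discount and the `reward`/`done` tensors:

import torch

def td3_target_sketch(agent, next_obs, prev_action, prev_reward,
                      reward, done, discount=0.99):
    """Clipped double-Q target over the smoothed target action (sketch only)."""
    with torch.no_grad():
        q1, q2 = agent.target_q_at_mu(next_obs, prev_action, prev_reward)
        target_q = torch.min(q1, q2)
    return reward + (1.0 - done.float()) * discount * target_q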
Example #7
File: mlp.py  Project: Xingyu-Lin/softagent
class GumbelPiMlpModel(torch.nn.Module):
    """For picking corners"""

    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            all_corners=False
            ):
        super().__init__()
        self._obs_ndim = 1
        self._all_corners = all_corners
        input_dim = int(np.sum(observation_shape))

        print('all corners', self._all_corners)
        delta_dim = 12 if all_corners else 3
        self._delta_dim = delta_dim
        self.mlp = MlpModel(
            input_size=input_dim,
            hidden_sizes=hidden_sizes,
            output_size=2 * delta_dim + 4,  # mean and log_std per delta dim, plus 4 corner logits.
        )

        self.delta_distribution = Gaussian(
            dim=delta_dim,
            squash=True,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.cat_distribution = Categorical(4)


    def forward(self, observation, prev_action, prev_reward):
        if isinstance(observation, tuple):
            observation = torch.cat(observation, dim=-1)

        lead_dim, T, B, _ = infer_leading_dims(observation,
            self._obs_ndim)
        output = self.mlp(observation.view(T * B, -1))
        logits = output[:, :4]
        mu, log_std = output[:, 4:4 + self._delta_dim], output[:, 4 + self._delta_dim:]
        logits, mu, log_std = restore_leading_dims((logits, mu, log_std), lead_dim, T, B)
        return GumbelDistInfo(cat_dist=logits, delta_dist=DistInfoStd(mean=mu, log_std=log_std))

    def sample_loglikelihood(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist

        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)

        cat_sample = torch.argmax(prob, dim=-1)
        cat_loglikelihood = select_at_indexes(cat_sample, prob)

        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
        one_hot = (one_hot - prob).detach() + prob # Make action differentiable through prob

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = mu[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            log_std = log_std[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        delta_sample, delta_loglikelihood = self.delta_distribution.sample_loglikelihood(new_dist_info)
        action = torch.cat((one_hot, delta_sample), dim=-1)
        log_likelihood = cat_loglikelihood + delta_loglikelihood
        return action, log_likelihood

    def sample(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)

        cat_sample = torch.argmax(prob, dim=-1)
        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)

        if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
            cat_sample = cat_sample.unsqueeze(0)

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = select_at_indexes(cat_sample, mu)
            log_std = select_at_indexes(cat_sample, log_std)

            if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
                mu, log_std = mu.squeeze(0), log_std.squeeze(0)

            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        if self.training:
            self.delta_distribution.set_std(None)
        else:
            self.delta_distribution.set_std(0)
        delta_sample = self.delta_distribution.sample(new_dist_info)
        return torch.cat((one_hot, delta_sample), dim=-1)
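The line `one_hot = (one_hot - prob).detach() + prob` above is a straight-through estimator: the forward value is the hard one-hot sample, while gradients flow back through the softmax probabilities. A minimal standalone illustration of that trick (the temperature of 10 used above is kept here only as a placeholder):

import torch
import torch.nn.functional as F

def straight_through_gumbel_sketch(logits, temperature=10.0):
    """Hard one-hot in the forward pass, soft-probability gradients in the backward pass."""
    u = torch.rand_like(logits).clamp(1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))
    prob = F.softmax((logits + gumbel) / temperature, dim=-1)
    hard = F.one_hot(prob.argmax(dim=-1), logits.shape[-1]).float()
    return (hard - prob).detach() + prob

logits = torch.zeros(2, 4, requires_grad=True)
sample = straight_through_gumbel_sketch(logits)
sample.sum().backward()  # Gradients reach `logits` despite the argmax.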
Example #8
File: mlp.py  Project: Xingyu-Lin/softagent
class GumbelAutoregPiMlpModel(torch.nn.Module):
    """For picking corners autoregressively"""

    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            n_tile=20,
    ):
        super().__init__()
        self._obs_ndim = 1
        self._n_tile = n_tile
        input_dim = int(np.sum(observation_shape))

        self._action_size = action_size
        self.mlp_loc = MlpModel(
            input_size=input_dim,
            hidden_sizes=hidden_sizes,
            output_size=4
        )
        self.mlp_delta = MlpModel(
            input_size=input_dim + 4 * n_tile,
            hidden_sizes=hidden_sizes,
            output_size=3 * 2,
        )

        self.delta_distribution = Gaussian(
            dim=3,
            squash=True,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.cat_distribution = Categorical(4)

        self._counter = 0

    def start(self):
        self._counter = 0

    def next(self, actions, observation, prev_action, prev_reward):
        if isinstance(observation, tuple):
            observation = torch.cat(observation, dim=-1)

        lead_dim, T, B, _ = infer_leading_dims(observation,
                                               self._obs_ndim)
        input_obs = observation.view(T * B, -1)
        if self._counter == 0:
            logits = self.mlp_loc(input_obs)
            logits = restore_leading_dims(logits, lead_dim, T, B)
            self._counter += 1
            return logits

        elif self._counter == 1:
            assert len(actions) == 1
            action_loc = actions[0].view(T * B, -1)
            model_input = torch.cat((input_obs, action_loc.repeat((1, self._n_tile))), dim=-1)
            output = self.mlp_delta(model_input)
            mu, log_std = output.chunk(2, dim=-1)
            mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
            self._counter += 1
            return DistInfoStd(mean=mu, log_std=log_std)
        else:
            raise ValueError(f"Invalid self._counter: {self._counter}")

    def has_next(self):
        return self._counter < 2

    def sample_loglikelihood(self, dist_info):
        if isinstance(dist_info, DistInfoStd):
            action, log_likelihood = self.delta_distribution.sample_loglikelihood(dist_info)
        else:
            logits = dist_info

            u = torch.rand_like(logits)
            u = torch.clamp(u, 1e-5, 1 - 1e-5)
            gumbel = -torch.log(-torch.log(u))
            prob = F.softmax((logits + gumbel) / 10, dim=-1)

            cat_sample = torch.argmax(prob, dim=-1)
            log_likelihood = select_at_indexes(cat_sample, prob)

            one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
            action = (one_hot - prob).detach() + prob  # Make action differentiable through prob

        return action, log_likelihood

    def sample(self, dist_info):
        if isinstance(dist_info, DistInfoStd):
            if self.training:
                self.delta_distribution.set_std(None)
            else:
                self.delta_distribution.set_std(0)
            action = self.delta_distribution.sample(dist_info)
        else:
            logits = dist_info
            u = torch.rand_like(logits)
            u = torch.clamp(u, 1e-5, 1 - 1e-5)
            gumbel = -torch.log(-torch.log(u))
            prob = F.softmax((logits + gumbel) / 10, dim=-1)

            cat_sample = torch.argmax(prob, dim=-1)
            action = to_onehot(cat_sample, 4, dtype=torch.float32)

        return action
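The `start()` / `next()` / `has_next()` interface above is meant to be driven by an outer loop that feeds each sampled sub-action back into the model. A hedged sketch of such a driver (the loop itself is not part of the class; the function and variable names are illustrative):

def autoreg_sample_sketch(model, observation, prev_action, prev_reward):
    """Sample the (location, delta) action pieces one at a time."""
    model.start()
    actions = []
    while model.has_next():
        dist_info = model.next(actions, observation, prev_action, prev_reward)
        actions.append(model.sample(dist_info))
    return actions  # [one_hot_location, squashed_delta]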
Example #9
class DdpgAgent(BaseAgent):

    shared_mu_model = None

    def __init__(
        self,
        MuModelCls=MuMlpModel,
        QModelCls=QofMuMlpModel,
        mu_model_kwargs=None,
        q_model_kwargs=None,
        initial_mu_model_state_dict=None,
        initial_q_model_state_dict=None,
        action_std=0.1,
        action_noise_clip=None,
    ):
        if mu_model_kwargs is None:
            mu_model_kwargs = dict(hidden_sizes=[400, 300])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[400, 300])
        save__init__args(locals())
        self.min_itr_learn = 0  # Used in TD3

    def initialize(self, env_spaces, share_memory=False):
        env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)
        self.mu_model = self.MuModelCls(**env_model_kwargs,
                                        **self.mu_model_kwargs)
        self.q_model = self.QModelCls(**env_model_kwargs,
                                      **self.q_model_kwargs)
        if share_memory:
            self.mu_model.share_memory()
            # self.q_model.share_memory()  # Not needed for sampling.
            self.shared_mu_model = self.mu_model
            # self.shared_q_model = self.q_model
        if self.initial_mu_model_state_dict is not None:
            self.mu_model.load_state_dict(self.initial_mu_model_state_dict)
        if self.initial_q_model_state_dict is not None:
            self.q_model.load_state_dict(self.initial_q_model_state_dict)
        self.target_mu_model = self.MuModelCls(**env_model_kwargs,
                                               **self.mu_model_kwargs)
        self.target_mu_model.load_state_dict(self.mu_model.state_dict())
        self.target_q_model = self.QModelCls(**env_model_kwargs,
                                             **self.q_model_kwargs)
        self.target_q_model.load_state_dict(self.q_model.state_dict())
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            std=self.action_std,
            noise_clip=self.action_noise_clip,
            clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
        )
        self.env_spaces = env_spaces
        self.env_model_kwargs = env_model_kwargs

    def initialize_cuda(self, cuda_idx=None, ddp=False):
        if cuda_idx is None:
            return  # CPU
        if self.shared_mu_model is not None:
            self.mu_model = self.MuModelCls(**self.env_model_kwargs,
                                            **self.mu_model_kwargs)
            self.mu_model.load_state_dict(self.shared_mu_model.state_dict())
        self.device = torch.device("cuda", index=cuda_idx)
        self.mu_model.to(self.device)
        self.q_model.to(self.device)
        if ddp:
            self.mu_model = DDP(self.mu_model,
                                device_ids=[cuda_idx],
                                output_device=cuda_idx)
            self.q_model = DDP(self.q_model,
                               device_ids=[cuda_idx],
                               output_device=cuda_idx)
            logger.log("Initialized DistributedDataParallel agent model "
                       f"on device: {self.device}.")
        else:
            logger.log(f"Initialized agent models on device: {self.device}.")
        self.target_mu_model.to(self.device)
        self.target_q_model.to(self.device)

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # Used in TD3

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q = self.q_model(*model_inputs)
        return q.cpu()

    def q_at_mu(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu = self.mu_model(*model_inputs)
        q = self.q_model(*model_inputs, mu)
        return q.cpu()

    def target_q_at_mu(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        target_mu = self.target_mu_model(*model_inputs)
        target_q_at_mu = self.target_q_model(*model_inputs, target_mu)
        return target_q_at_mu.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu = self.mu_model(*model_inputs)
        action = self.distribution.sample(DistInfo(mean=mu))
        agent_info = AgentInfo(mu=mu)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_mu_model, self.mu_model.state_dict(), tau)
        update_state_dict(self.target_q_model, self.q_model.state_dict(), tau)

    def sync_shared_memory(self):
        if self.shared_mu_model is not self.mu_model:
            self.shared_mu_model.load_state_dict(self.mu_model.state_dict())

    def recv_shared_memory(self):
        with self._rw_lock:
            if self.mu_model is not self.shared_mu_model:
                self.mu_model.load_state_dict(self.shared_mu_model.state_dict())

    def q_parameters(self):
        return self.q_model.parameters()

    def mu_parameters(self):
        return self.mu_model.parameters()

    def train_mode(self, itr):
        self.q_model.train()
        self.mu_model.train()
        self._mode = "train"

    def sample_mode(self, itr):
        self.q_model.eval()
        self.mu_model.eval()
        self.distribution.set_std(self.action_std)
        self._mode = "sample"

    def eval_mode(self, itr):
        self.q_model.eval()
        self.mu_model.eval()
        self.distribution.set_std(0.)  # Deterministic
        self._mode = "eval"

    def state_dict(self):
        return dict(
            q_model=self.q_model.state_dict(),
            mu_model=self.mu_model.state_dict(),
            q_target=self.target_q_model.state_dict(),
            mu_target=self.target_mu_model.state_dict(),
        )
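This older-style agent keeps a `shared_mu_model` in shared memory so CPU sampler workers can read the latest policy while the training copy may live elsewhere; `sync_shared_memory()` pushes trainer weights out and `recv_shared_memory()` pulls them back into a worker's local copy. A sketch of that round trip, with two plain modules standing in for the trainer and shared models:

import torch.nn as nn

trainer_model = nn.Linear(4, 2)   # Stand-in for the trainer's mu_model.
shared_model = nn.Linear(4, 2)    # Stand-in for shared_mu_model.
shared_model.share_memory()       # Place parameters in shared memory for workers.

# Trainer -> shared copy (as in sync_shared_memory):
shared_model.load_state_dict(trainer_model.state_dict())
# Worker's local copy <- shared copy (as in recv_shared_memory):
trainer_model.load_state_dict(shared_model.state_dict())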
Example #10
class SacAgent(BaseAgent):

    shared_pi_model = None

    def __init__(
        self,
        ModelCls=PiMlpModel,  # Pi model.
        QModelCls=QofMuMlpModel,
        model_kwargs=None,  # Pi model.
        q_model_kwargs=None,
        initial_model_state_dict=None,  # Pi model.
        action_squash=1,  # Max magnitude (or None).
        pretrain_std=0.75,  # High value to make near uniform sampling.
        max_q_eval_mode='none',
        n_qs=2,
    ):
        self._max_q_eval_mode = max_q_eval_mode
        if isinstance(ModelCls, str):
            ModelCls = eval(ModelCls)
        if isinstance(QModelCls, str):
            QModelCls = eval(QModelCls)

        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[256, 256])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[256, 256])
        super().__init__(ModelCls=ModelCls,
                         model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict
                         )  # For async setup.
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

        self.log_alpha = None
        print('n_qs', self.n_qs)

        global Models
        Models = namedtuple("Models",
                            ["pi"] + [f"q{i}" for i in range(self.n_qs)])

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q_models = [
            self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
            for _ in range(self.n_qs)
        ]

        self.target_q_models = [
            self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
            for _ in range(self.n_qs)
        ]
        for target_q, q in zip(self.target_q_models, self.q_models):
            target_q.load_state_dict(q.state_dict())

        self.log_alpha = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        [q.to(self.device) for q in self.q_models]
        [q_target.to(self.device) for q_target in self.target_q_models]
        # Parameter.to() is not in-place; reassign the data tensor.
        self.log_alpha.data = self.log_alpha.data.to(self.device)

    def data_parallel(self):
        super().data_parallel()
        DDP_WRAP = DDPC if self.device.type == "cpu" else DDP
        self.q_models = [DDP_WRAP(q) for q in self.q_models]

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        qs = [q(*model_inputs) for q in self.q_models]
        return [q.cpu() for q in qs]

    def target_q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        qs = [target_q(*model_inputs) for target_q in self.target_q_models]
        return [q.cpu() for q in qs]

    def pi(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mean, log_std = self.model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    def target_v(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        next_actions, next_log_pis, _ = self.pi(*model_inputs)

        qs = self.target_q(observation, prev_action, prev_reward, next_actions)
        min_next_q = torch.min(torch.stack(qs, dim=0), dim=0)[0]

        alpha = self.log_alpha.exp().detach().cpu()
        target_v = min_next_q - alpha * next_log_pis
        return target_v.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        threshold = 0.2
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        if self._max_q_eval_mode == 'none':
            mean, log_std = self.model(*model_inputs)
            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            action = self.distribution.sample(dist_info)
            agent_info = AgentInfo(dist_info=dist_info)
            action, agent_info = buffer_to((action, agent_info), device="cpu")
            return AgentStep(action=action, agent_info=agent_info)
        else:
            global MaxQInput
            observation, prev_action, prev_reward = model_inputs
            fields = observation._fields
            if 'position' in fields:
                no_batch = len(observation.position.shape) == 1
            else:
                no_batch = len(observation.pixels.shape) == 3
            if no_batch:
                if 'state' in self._max_q_eval_mode:
                    observation = [observation.position.unsqueeze(0)]
                else:
                    observation = [observation.pixels.unsqueeze(0)]
            else:
                if 'state' in self._max_q_eval_mode:
                    observation = [observation.position]
                else:
                    observation = [observation.pixels]

            if self._max_q_eval_mode == 'state_rope':
                locations = np.arange(25).astype('float32')
                locations = locations[:, None]
                locations = np.tile(locations, (1, 50)) / 24
            elif self._max_q_eval_mode == 'state_cloth_corner':
                locations = np.array(
                    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
                    dtype='float32')
                locations = np.tile(locations, (1, 50))
            elif self._max_q_eval_mode == 'state_cloth_point':
                locations = np.mgrid[0:9, 0:9].reshape(2,
                                                       81).T.astype('float32')
                locations = np.tile(locations, (1, 50)) / 8
            elif self._max_q_eval_mode == 'pixel_rope':
                image = observation[0].squeeze(0).cpu().numpy()
                locations = np.transpose(np.where(np.all(
                    image > 150, axis=2))).astype('float32')
                if locations.shape[0] == 0:
                    locations = np.array([[-1, -1]], dtype='float32')
                locations = np.tile(locations, (1, 50)) / 63
            elif self._max_q_eval_mode == 'pixel_cloth':
                image = observation[0].squeeze(0).cpu().numpy()
                locations = np.transpose(np.where(np.any(
                    image < 100, axis=-1))).astype('float32')
                locations = np.tile(locations, (1, 50)) / 63
            else:
                raise ValueError(
                    f"Unknown max_q_eval_mode: {self._max_q_eval_mode}")

            observation_pi = self.model.forward_embedding(observation)
            observation_qs = [
                q.forward_embedding(observation) for q in self.q_models
            ]

            n_locations = len(locations)
            observation_pi_i = [
                repeat(o[[i]], [n_locations] + [1] * len(o.shape[1:]))
                for o in observation_pi
            ]
            observation_qs_i = [[
                repeat(o, [n_locations] + [1] * len(o.shape[1:]))
                for o in observation_q
            ] for observation_q in observation_qs]
            locations = torch.from_numpy(locations).to(self.device)

            if MaxQInput is None:
                MaxQInput = namedtuple('MaxQPolicyInput', fields)

            aug_observation_pi = [locations] + list(observation_pi_i)
            aug_observation_pi = MaxQInput(*aug_observation_pi)
            aug_observation_qs = [[locations] + list(observation_q_i)
                                  for observation_q_i in observation_qs_i]
            aug_observation_qs = [
                MaxQInput(*aug_observation_q)
                for aug_observation_q in aug_observation_qs
            ]

            mean, log_std = self.model.forward_output(
                aug_observation_pi)  #, prev_action, prev_reward)

            qs = [
                q.forward_output(aug_obs, mean)
                for q, aug_obs in zip(self.q_models, aug_observation_qs)
            ]
            q = torch.min(torch.stack(qs, dim=0), dim=0)[0]
            #q = q.view(batch_size, n_locations)

            values, indices = torch.topk(q,
                                         math.ceil(threshold * n_locations),
                                         dim=-1)

            # vmin, vmax = values.min(dim=-1, keepdim=True)[0], values.max(dim=-1, keepdim=True)[0]
            # values = (values - vmin) / (vmax - vmin)
            # values = F.log_softmax(values, -1)
            #
            # uniform = torch.rand_like(values)
            # uniform = torch.clamp(uniform, 1e-5, 1 - 1e-5)
            # gumbel = -torch.log(-torch.log(uniform))

            #sampled_idx = torch.argmax(values + gumbel, dim=-1)
            sampled_idx = torch.randint(high=math.ceil(threshold *
                                                       n_locations),
                                        size=(1, )).to(self.device)

            actual_idxs = indices[sampled_idx]
            #actual_idxs += (torch.arange(batch_size) * n_locations).to(self.device)

            location = locations[actual_idxs][:, :1]
            location = (location - 0.5) / 0.5
            delta = torch.tanh(mean[actual_idxs])
            action = torch.cat((location, delta), dim=-1)

            mean, log_std = mean[actual_idxs], log_std[actual_idxs]

            if no_batch:
                action = action.squeeze(0)
                mean = mean.squeeze(0)
                log_std = log_std.squeeze(0)

            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            agent_info = AgentInfo(dist_info=dist_info)

            action, agent_info = buffer_to((action, agent_info), device="cpu")
            return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        for target_q, q in zip(self.target_q_models, self.q_models):
            update_state_dict(target_q, q.state_dict(), tau)

    @property
    def models(self):
        return Models(pi=self.model,
                      **{f'q{i}': q
                         for i, q in enumerate(self.q_models)})

    def parameters(self):
        for model in self.models:
            yield from model.parameters()
        yield self.log_alpha

    def pi_parameters(self):
        return self.model.parameters()

    def q_parameters(self):
        return [q.parameters() for q in self.q_models]

    def train_mode(self, itr):
        super().train_mode(itr)
        [q.train() for q in self.q_models]

    def sample_mode(self, itr):
        super().sample_mode(itr)
        [q.eval() for q in self.q_models]
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)
        [q.eval() for q in self.q_models]
        self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        rtn = dict(
            model=self.model.state_dict(),  # Pi model.
            alpha=self.log_alpha.data)
        rtn.update({
            f'q{i}_model': q.state_dict()
            for i, q in enumerate(self.q_models)
        })
        rtn.update({
            f'target_q{i}_model': q.state_dict()
            for i, q in enumerate(self.target_q_models)
        })
        return rtn

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.log_alpha.data = state_dict['alpha']
        print(state_dict.keys())

        for i, q in enumerate(self.q_models):
            q.load_state_dict(state_dict[f'q{i}_model'])
        for i, target_q in enumerate(self.target_q_models):
            target_q.load_state_dict(state_dict[f'target_q{i}_model'])
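In Example #10 above, target_v forms the soft (entropy-regularized) target by taking the minimum over the Q ensemble and subtracting the temperature-weighted log-likelihood. A minimal standalone sketch of that computation; tensor names here are illustrative, not from the source:

import torch

def soft_target_value(q_values, log_pi, log_alpha):
    """Min over the Q ensemble minus the entropy term alpha * log pi(a|s)."""
    min_q = torch.min(torch.stack(q_values, dim=0), dim=0).values
    alpha = log_alpha.exp().detach()  # temperature, detached as in the agent above
    return min_q - alpha * log_pi

# Toy usage: two Q estimates over a batch of three transitions.
qs = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.5, 1.0, 2.5])]
log_pi = torch.tensor([-1.0, -0.5, -2.0])
log_alpha = torch.tensor(0.0)  # alpha = 1
print(soft_target_value(qs, log_pi, log_alpha))  # tensor([2.0000, 1.5000, 4.5000])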
Example #11
class SacAgent(BaseAgent):

    shared_pi_model = None

    def __init__(
            self,
            QModelCls=QofMuMlpModel,
            VModelCls=VMlpModel,
            PiModelCls=PiMlpModel,
            q_model_kwargs=None,
            v_model_kwargs=None,
            pi_model_kwargs=None,
            initial_q1_model_state_dict=None,
            initial_q2_model_state_dict=None,
            initial_v_model_state_dict=None,
            initial_pi_model_state_dict=None,
            action_squash=1,  # Max magnitude (or None).
            pretrain_std=5.,  # High value to make near uniform sampling.
    ):
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[256, 256])
        if v_model_kwargs is None:
            v_model_kwargs = dict(hidden_sizes=[256, 256])
        if pi_model_kwargs is None:
            pi_model_kwargs = dict(hidden_sizes=[256, 256])
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self, env_spaces, share_memory=False):
        env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)
        self.q1_model = self.QModelCls(**env_model_kwargs,
                                       **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**env_model_kwargs,
                                       **self.q_model_kwargs)
        self.v_model = self.VModelCls(**env_model_kwargs,
                                      **self.v_model_kwargs)
        self.pi_model = self.PiModelCls(**env_model_kwargs,
                                        **self.pi_model_kwargs)
        if share_memory:
            self.pi_model.share_memory()  # Only one needed for sampling.
            self.shared_pi_model = self.pi_model
        if self.initial_q1_model_state_dict is not None:
            self.q1_model.load_state_dict(self.initial_q1_model_state_dict)
        if self.initial_q2_model_state_dict is not None:
            self.q2_model.load_state_dict(self.initial_q2_model_state_dict)
        if self.initial_v_model_state_dict is not None:
            self.v_model.load_state_dict(self.initial_v_model_state_dict)
        if self.initial_pi_model_state_dict is not None:
            self.pi_model.load_state_dict(self.initial_pi_model_state_dict)
        self.target_v_model = self.VModelCls(**env_model_kwargs,
                                             **self.v_model_kwargs)
        self.target_v_model.load_state_dict(self.v_model.state_dict())
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.env_spaces = env_spaces
        self.env_model_kwargs = env_model_kwargs

    def initialize_cuda(self, cuda_idx=None, ddp=False):
        if cuda_idx is None:
            return  # CPU
        if self.shared_pi_model is not None:
            self.pi_model = self.PiModelCls(**self.env_model_kwargs,
                                            **self.pi_model_kwargs)
            self.pi_model.load_state_dict(self.shared_pi_model.state_dict())
        self.device = torch.device("cuda", index=cuda_idx)
        self.q1_model.to(self.device)
        self.q2_model.to(self.device)
        self.v_model.to(self.device)
        self.pi_model.to(self.device)
        if ddp:
            self.q1_model = DDP(self.q1_model,
                                device_ids=[cuda_idx],
                                output_device=cuda_idx)
            self.q2_model = DDP(self.q2_model,
                                device_ids=[cuda_idx],
                                output_device=cuda_idx)
            self.v_model = DDP(self.v_model,
                               device_ids=[cuda_idx],
                               output_device=cuda_idx)
            self.pi_model = DDP(self.pi_model,
                                device_ids=[cuda_idx],
                                output_device=cuda_idx)
            logger.log("Initialized DistributedDataParallel agent model "
                       f"on device: {self.device}.")
        else:
            logger.log(f"Initialized agent models on device: {self.device}.")
        self.target_v_model.to(self.device)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q1 = self.q1_model(*model_inputs)
        q2 = self.q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def v(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        v = self.v_model(*model_inputs)
        return v.cpu()

    def pi(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mean, log_std = self.pi_model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        # action = self.distribution.sample(dist_info)
        # log_pi = self.distribution.log_likelihood(action, dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    def target_v(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        target_v = self.target_v_model(*model_inputs)
        return target_v.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mean, log_std = self.pi_model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        if np.any(np.isnan(action.numpy())):
            breakpoint()
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_v_model, self.v_model, tau)

    @property
    def models(self):
        return self.q1_model, self.q2_model, self.v_model, self.pi_model

    def parameters(self):
        for model in self.models:
            yield from model.parameters()

    def parameters_by_model(self):
        return (model.parameters() for model in self.models)

    def sync_shared_memory(self):
        if self.shared_pi_model is not self.pi_model:
            self.shared_pi_model.load_state_dict(self.pi_model.state_dict())

    def recv_shared_memory(self):
        if self.shared_pi_model is not self.pi_model:
            with self._rw_lock:
                self.pi_model.load_state_dict(
                    self.shared_pi_model.state_dict())

    def train_mode(self, itr):
        for model in self.models:
            model.train()
        self._mode = "train"

    def sample_mode(self, itr):
        for model in self.models:
            model.eval()
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.
        self._mode = "sample"

    def eval_mode(self, itr):
        for model in self.models:
            model.eval()
        self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).
        self._mode = "eval"

    def state_dict(self):
        return dict(
            q1_model=self.q1_model.state_dict(),
            q2_model=self.q2_model.state_dict(),
            v_model=self.v_model.state_dict(),
            pi_model=self.pi_model.state_dict(),
            v_target=self.target_v_model.state_dict(),
        )
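Example #11's pi method relies on the Gaussian distribution's sample_loglikelihood, which combines reparameterized sampling with the tanh-squashing correction to the log-density. A minimal standalone sketch of that idea (my own illustrative implementation, not rlpyt's):

import math
import torch

def squashed_gaussian_sample(mean, log_std, eps=1e-6):
    """Reparameterized sample u ~ N(mean, std); squashed action a = tanh(u) with corrected log-prob."""
    std = log_std.exp()
    u = mean + std * torch.randn_like(mean)
    normal_log_prob = (-0.5 * ((u - mean) / std) ** 2
                       - log_std
                       - 0.5 * math.log(2 * math.pi)).sum(-1)
    action = torch.tanh(u)
    # Change of variables for the squash: subtract sum_i log(1 - tanh(u_i)^2).
    log_pi = normal_log_prob - torch.log(1 - action.pow(2) + eps).sum(-1)
    return action, log_pi

mean, log_std = torch.zeros(2, 3), torch.full((2, 3), -1.0)
action, log_pi = squashed_gaussian_sample(mean, log_std)
print(action.shape, log_pi.shape)  # torch.Size([2, 3]) torch.Size([2])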
Example #12
class SafeSacAgent(SacAgent):
    """
    Agent for SAC algorithm, including action-squashing, using twin Q-values.

    Modifications:
    * prev_reward and prev_action aren't used

    Design decisions
    * The CNN parameters count as policy parameters; when updating Q1 / Q2, these are
      not updated.
    """
    def __init__(
            self,
            ModelCls=ImpalaSacModel,
            model_kwargs=None,
            initial_model_state_dict=None,
            action_squash=1.0,  # Max magnitude (or None).
            pretrain_std=0.75,  # With squash 0.75 is near uniform.
    ):
        """Saves input arguments; network defaults stored within."""
        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[256, 256])
        super(SacAgent, self).__init__(
            ModelCls=ModelCls,
            model_kwargs=model_kwargs,
            initial_model_state_dict=initial_model_state_dict,
        )
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super(SacAgent, self).initialize(env_spaces,
                                         share_memory,
                                         global_B=global_B,
                                         env_ranks=env_ranks)

        self.target_model = self.ModelCls(**self.env_model_kwargs,
                                          **self.model_kwargs)
        self.target_model.load_state_dict(self.model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )

    def to_device(self, cuda_idx=None):
        super(SacAgent, self).to_device(cuda_idx)
        self.target_model.to(self.device)

    def data_parallel(self):
        super(SacAgent, self).data_parallel()
        DDP_WRAP = DDPC if self.device.type == "cpu" else DDP
        self.target_model = DDP_WRAP(self.target_model)

    def q(self, observation, prev_action, prev_reward, action):
        """Compute twin Q-values for state/observation and input action (with grad)."""
        model_inputs = buffer_to((observation, action), device=self.device)
        q1, q2, _ = self.model(model_inputs, "q")
        return q1.cpu(), q2.cpu()

    def target_q(self, observation, prev_action, prev_reward, action):
        """Compute twin target Q-values for state/observation and input action."""
        model_inputs = buffer_to((observation, action), device=self.device)
        target_q1, target_q2, _ = self.target_model(model_inputs, "q")
        return target_q1.cpu(), target_q2.cpu()

    def pi(self, observation, prev_action, prev_reward):
        """Compute action log-probabilities for state/observation, and
        sample new action (with grad).  Uses special ``sample_loglikelihood()``
        method of Gaussian distribution, which handles action squashing
        through this process."""
        model_inputs = buffer_to(observation, device=self.device)
        mean, log_std, _ = self.model(model_inputs, "pi")
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to(observation, device=self.device)
        mean, log_std, sym_features = self.model(model_inputs,
                                                 "pi",
                                                 extract_sym_features=True)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        agent_info = SafeSacAgentInfo(dist_info=dist_info,
                                      sym_features=sym_features)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_model, self.model.state_dict(), tau)

    def pi_parameters(self):
        return itertools.chain(self.model.pi.parameters(),
                               self.model.feature_extractor.parameters())

    def q1_parameters(self):
        return self.model.q1.parameters()

    def q2_parameters(self):
        return self.model.q2.parameters()

    def train_mode(self, itr):
        super(SacAgent, self).train_mode(itr)
        self.target_model.train()

    def sample_mode(self, itr):
        super(SacAgent, self).sample_mode(itr)
        self.target_model.eval()
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super(SacAgent, self).eval_mode(itr)
        self.target_model.eval()
        self.distribution.set_std(
            0.0)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.target_model.load_state_dict(state_dict["target_model"])
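Example #12's design decision, training the shared feature extractor only through the policy loss, amounts to splitting the optimizer parameter groups. A rough sketch of that split; the stand-in model below only mirrors the submodule names feature_extractor, pi, q1, and q2 assumed above, it is not ImpalaSacModel:

import itertools
import torch
import torch.nn as nn

class TinySacModel(nn.Module):
    """Stand-in with the submodule layout assumed by SafeSacAgent."""
    def __init__(self, obs_dim=8, act_dim=2, hidden=32):
        super().__init__()
        self.feature_extractor = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.pi = nn.Linear(hidden, 2 * act_dim)       # outputs mean and log_std
        self.q1 = nn.Linear(hidden + act_dim, 1)
        self.q2 = nn.Linear(hidden + act_dim, 1)

model = TinySacModel()
# The policy optimizer also owns the shared encoder; Q optimizers touch only their heads.
pi_optim = torch.optim.Adam(
    itertools.chain(model.pi.parameters(), model.feature_extractor.parameters()), lr=3e-4)
q1_optim = torch.optim.Adam(model.q1.parameters(), lr=3e-4)
q2_optim = torch.optim.Adam(model.q2.parameters(), lr=3e-4)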
Example #13
class SacAgent(BaseAgent):

    shared_pi_model = None

    def __init__(
            self,
            ModelCls=AutoregPiMlpModel,  # Pi model.
            QModelCls=QofMuMlpModel,
            model_kwargs=None,  # Pi model.
            q_model_kwargs=None,
            initial_model_state_dict=None,  # Pi model.
            action_squash=1,  # Max magnitude (or None).
            pretrain_std=0.75,  # High value to make near uniform sampling.
    ):
        if isinstance(ModelCls, str):
            ModelCls = eval(ModelCls)
        if isinstance(QModelCls, str):
            QModelCls = eval(QModelCls)

        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[256, 256])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[256, 256])
        super().__init__(ModelCls=ModelCls,
                         model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict
                         )  # For async setup.
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

        self.log_alpha = None

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q1_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)

        self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        self.target_q2_model.load_state_dict(self.q2_model.state_dict())

        self.log_alpha = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        self.q1_model.to(self.device)
        self.q2_model.to(self.device)
        self.target_q1_model.to(self.device)
        self.target_q2_model.to(self.device)
        # Parameter.to() is not in-place; reassign the data tensor.
        self.log_alpha.data = self.log_alpha.data.to(self.device)

    def data_parallel(self):
        super().data_parallel()
        DDP_WRAP = DDPC if self.device.type == "cpu" else DDP
        self.q1_model = DDP_WRAP(self.q1_model)
        self.q2_model = DDP_WRAP(self.q2_model)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        assert len(env_spaces.action.shape) == 1
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.shape[0],
        )

    def q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q1 = self.q1_model(*model_inputs)
        q2 = self.q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def target_q(self, observation, prev_action, prev_reward, action):
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, action),
            device=self.device)
        q1 = self.target_q1_model(*model_inputs)
        q2 = self.target_q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def pi(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        actions, means, log_stds = [], [], []
        log_pi_total = 0
        self.model.start()
        while self.model.has_next():
            mean, log_std = self.model.next(actions, *model_inputs)
            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            action, log_pi = self.distribution.sample_loglikelihood(dist_info)

            log_pi_total += log_pi
            actions.append(action)
            means.append(mean)
            log_stds.append(log_std)

        mean, log_std = torch.cat(means, dim=-1), torch.cat(log_stds, dim=-1)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)

        log_pi_total, dist_info = buffer_to((log_pi_total, dist_info),
                                            device="cpu")
        action = torch.cat(actions, dim=-1)
        return action, log_pi_total, dist_info  # Action stays on device for q models.

    def target_v(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        next_actions, next_log_pis, _ = self.pi(*model_inputs)

        q1, q2 = self.target_q(observation, prev_action, prev_reward,
                               next_actions)
        min_next_q = torch.min(q1, q2)

        alpha = self.log_alpha.exp().detach().cpu()
        target_v = min_next_q - alpha * next_log_pis
        return target_v.cpu()

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        actions, means, log_stds = [], [], []
        self.model.start()
        while self.model.has_next():
            mean, log_std = self.model.next(actions, *model_inputs)
            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            action = self.distribution.sample(dist_info)

            actions.append(action)
            means.append(mean)
            log_stds.append(log_std)

        mean, log_std = torch.cat(means, dim=-1), torch.cat(log_stds, dim=-1)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        agent_info = AgentInfo(dist_info=dist_info)

        action = torch.cat(actions, dim=-1)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_q1_model, self.q1_model.state_dict(),
                          tau)
        update_state_dict(self.target_q2_model, self.q2_model.state_dict(),
                          tau)

    @property
    def models(self):
        return Models(pi=self.model, q1=self.q1_model, q2=self.q2_model)

    def parameters(self):
        for model in self.models:
            yield from model.parameters()
        yield self.log_alpha

    def pi_parameters(self):
        return self.model.parameters()

    def q1_parameters(self):
        return self.q1_model.parameters()

    def q2_parameters(self):
        return self.q2_model.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q1_model.train()
        self.q2_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        std = None if itr >= self.min_itr_learn else self.pretrain_std
        self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),  # Pi model.
            q1_model=self.q1_model.state_dict(),
            q2_model=self.q2_model.state_dict(),
            target_q1_model=self.target_q1_model.state_dict(),
            target_q2_model=self.target_q2_model.state_dict(),
            alpha=self.log_alpha.data)

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.q1_model.load_state_dict(state_dict["q1_model"])
        self.q2_model.load_state_dict(state_dict["q2_model"])
        self.target_q1_model.load_state_dict(state_dict['target_q1_model'])
        self.target_q2_model.load_state_dict(state_dict['target_q2_model'])
        self.log_alpha.data = state_dict['alpha']
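A final note on the temperature parameter used in several of these examples: Parameter.to(device) is not in-place, which is why the to_device methods above reassign log_alpha.data. A minimal sketch of the pattern:

import torch
import torch.nn as nn

log_alpha = nn.Parameter(torch.tensor(0.0))  # alpha = exp(log_alpha)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

log_alpha.to(device)                         # returns a moved copy; log_alpha itself is unchanged
log_alpha.data = log_alpha.data.to(device)   # actually moves the parameter's storage
alpha = log_alpha.exp().detach()             # temperature term used in the SAC targets above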