Example #1
    def optim_initialize(self, rank=0):
        """Called in initialize or by async runner after forking sampler."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        if self.fixed_alpha is None:
            self.target_entropy = -np.log(
                (1.0 / self.agent.env_spaces.action.n)) * 0.98
            self._log_alpha = torch.zeros(1, requires_grad=True)
            self._alpha = self._log_alpha.exp()
            self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                                 lr=self.learning_rate,
                                                 **self.optim_kwargs)
        else:
            self._log_alpha = torch.tensor([np.log(self.fixed_alpha)])
            self._alpha = torch.tensor([self.fixed_alpha])
            self.alpha_optimizer = None
        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.n)

        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(dim=np.prod(
                self.agent.env_spaces.action.shape),
                                                      std=1.)
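
In the discrete-action branch above, the target entropy is set to 98% of the maximum (uniform-policy) entropy, since -log(1/n) = log(n). A quick numeric check of that expression, assuming a hypothetical 6-action space:

    import numpy as np

    n_actions = 6  # hypothetical discrete action count
    max_entropy = -np.log(1.0 / n_actions)  # log(6) ~ 1.792
    target_entropy = 0.98 * max_entropy     # ~ 1.756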
Example #2
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  #! Don't let base agent try to load.
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict

        self.q_model = self.QModelCls(**self.env_model_kwargs,
                                      **self.q_model_kwargs)
        self.target_q_model = self.QModelCls(**self.env_model_kwargs,
                                             **self.q_model_kwargs)
        self.target_q_model.load_state_dict(self.q_model.state_dict())

        if self.initial_model_state_dict is not None and not self.load_model_after_min_steps:
            self.load_state_dict(self.initial_model_state_dict)

        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        # Tie weights (need to make sure False if not using encoder)
        if self.tie_weights:
            self.model.encoder.copy_conv_weights_from(self.q_model.encoder)
Example #3
 def initialize(self, env_spaces, share_memory=False):
     env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)
     self.q1_model = self.QModelCls(**env_model_kwargs,
                                    **self.q_model_kwargs)
     self.q2_model = self.QModelCls(**env_model_kwargs,
                                    **self.q_model_kwargs)
     self.v_model = self.VModelCls(**env_model_kwargs,
                                   **self.v_model_kwargs)
     self.pi_model = self.PiModelCls(**env_model_kwargs,
                                     **self.pi_model_kwargs)
     if share_memory:
         self.pi_model.share_memory()  # Only one needed for sampling.
         self.shared_pi_model = self.pi_model
     if self.initial_q1_model_state_dict is not None:
         self.q1_model.load_state_dict(self.initial_q1_model_state_dict)
     if self.initial_q2_model_state_dict is not None:
         self.q2_model.load_state_dict(self.initial_q2_model_state_dict)
     if self.initial_v_model_state_dict is not None:
         self.v_model.load_state_dict(self.initial_v_model_state_dict)
     if self.initial_pi_model_state_dict is not None:
         self.pi_model.load_state_dict(self.initial_pi_model_state_dict)
     self.target_v_model = self.VModelCls(**env_model_kwargs,
                                          **self.v_model_kwargs)
     self.target_v_model.load_state_dict(self.v_model.state_dict())
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         squash=self.action_squash,
         min_std=np.exp(MIN_LOG_STD),
         max_std=np.exp(MAX_LOG_STD),
     )
     self.env_spaces = env_spaces
     self.env_model_kwargs = env_model_kwargs
Example #4
    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            n_tile=20,
    ):
        super().__init__()
        self._obs_ndim = 1
        self._n_tile = n_tile
        input_dim = int(np.sum(observation_shape))

        self._action_size = action_size
        self.mlp_loc = MlpModel(
            input_size=input_dim,
            hidden_sizes=hidden_sizes,
            output_size=4
        )
        self.mlp_delta = MlpModel(
            input_size=input_dim + 4 * n_tile,
            hidden_sizes=hidden_sizes,
            output_size=3 * 2,
        )

        self.delta_distribution = Gaussian(
            dim=3,
            squash=True,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.cat_distribution = Categorical(4)

        self._counter = 0
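
The input_dim + 4 * n_tile input size of mlp_delta is consistent with concatenating the observation features with the 4-way location one-hot repeated n_tile times. The forward pass is not shown in this example, so the sketch below only checks the size arithmetic with hypothetical names and shapes:

    import torch

    batch_size, obs_dim, n_tile = 2, 16, 20          # hypothetical sizes
    obs = torch.randn(batch_size, obs_dim)
    loc_onehot = torch.eye(4)[torch.tensor([1, 3])]  # one location per batch element
    delta_input = torch.cat([obs, loc_onehot.repeat(1, n_tile)], dim=-1)
    print(delta_input.shape)  # torch.Size([2, 96]) == [batch, obs_dim + 4 * n_tile]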
Example #5
 def initialize(self, env_spaces, share_memory=False,
                global_B=1, env_ranks=None):
     _initial_model_state_dict = self.initial_model_state_dict
     # Don't let base agent try to load.
     self.initial_model_state_dict = None
     super().initialize(env_spaces, share_memory,
                        global_B=global_B, env_ranks=env_ranks)
     self.initial_model_state_dict = _initial_model_state_dict
     self.q1_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
     self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
     self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
                                           **self.q_model_kwargs)
     self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
                                           **self.q_model_kwargs)
     self.target_q1_model.load_state_dict(self.q1_model.state_dict())
     self.target_q2_model.load_state_dict(self.q2_model.state_dict())
     if self.initial_model_state_dict is not None:
         self.load_state_dict(self.initial_model_state_dict)
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         squash=self.action_squash,
         min_std=np.exp(MIN_LOG_STD),
         max_std=np.exp(MAX_LOG_STD),
     )
Example #6
    def initialize(self, env_spaces, share_memory=False,
                   global_B=1, env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None
        super().initialize(env_spaces, share_memory,
                           global_B=global_B, env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q_models = [self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
                         for _ in range(self.n_qs)]

        self.target_q_models = [self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
                                for _ in range(self.n_qs)]
        [target_q.load_state_dict(q.state_dict())
         for target_q, q in zip(self.target_q_models, self.q_models)]

        self.log_alpha = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
Example #7
 def optim_initialize(self, rank=0):
     """Called by async runner."""
     self.rank = rank
     self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                       lr=self.learning_rate,
                                       **self.optim_kwargs)
     self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                       lr=self.learning_rate,
                                       **self.optim_kwargs)
     self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                       lr=self.learning_rate,
                                       **self.optim_kwargs)
     self._log_alpha = torch.zeros(1, requires_grad=True)
     self._alpha = torch.exp(self._log_alpha.detach())
     self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
     if self.target_entropy == "auto":
         self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
     if self.initial_optim_state_dict is not None:
         self.load_optim_state_dict(self.initial_optim_state_dict)
     if self.action_prior == "gaussian":
         self.action_prior_distribution = Gaussian(dim=np.prod(
             self.agent.env_spaces.action.shape),
                                                   std=1.)
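
Here _log_alpha is a leaf tensor with requires_grad=True, while _alpha is taken from its detached exponential, so gradients flow only into _log_alpha. The temperature loss itself is not part of this snippet; a minimal standalone sketch of the standard SAC entropy-temperature objective, with hypothetical values for the sampled log-probabilities (the optimizer step is omitted):

    import torch

    log_alpha = torch.zeros(1, requires_grad=True)  # as in optim_initialize above
    target_entropy = -3.0                           # e.g. the "auto" value for a 3-dim action space
    log_pi = torch.tensor([-2.5, -3.5, -3.1])       # hypothetical sampled log-probabilities
    alpha_loss = -(log_alpha * (log_pi.detach() + target_entropy)).mean()
    alpha_loss.backward()                           # gradient reaches only log_alpha
    alpha = log_alpha.exp().detach()                # detached value used in the other losses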
Example #8
 def initialize(self,
                env_spaces,
                share_memory=False,
                global_B=1,
                env_ranks=None):
     """Instantiates mu and q, and target_mu and target_q models."""
     super().initialize(env_spaces,
                        share_memory,
                        global_B=global_B,
                        env_ranks=env_ranks)
     self.q_model = self.QModelCls(**self.env_model_kwargs,
                                   **self.q_model_kwargs)
     if self.initial_q_model_state_dict is not None:
         self.q_model.load_state_dict(self.initial_q_model_state_dict)
     self.target_model = self.ModelCls(**self.env_model_kwargs,
                                       **self.model_kwargs)
     self.target_q_model = self.QModelCls(**self.env_model_kwargs,
                                          **self.q_model_kwargs)
     self.target_q_model.load_state_dict(self.q_model.state_dict())
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         std=self.action_std,
         noise_clip=self.action_noise_clip,
         clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
     )
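
The target models built here are typically kept in sync elsewhere by Polyak (soft) averaging rather than being rebuilt; this is a generic PyTorch sketch with a hypothetical coefficient tau, not a method taken from this codebase:

    import torch

    def soft_update(model, target_model, tau=0.01):
        # Move each target parameter a small step toward the online parameter.
        with torch.no_grad():
            for p, tp in zip(model.parameters(), target_model.parameters()):
                tp.mul_(1.0 - tau).add_(tau * p)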
Example #9
    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            all_corners=False
            ):
        super().__init__()
        self._obs_ndim = 1
        self._all_corners = all_corners
        input_dim = int(np.sum(observation_shape))

        print('all corners', self._all_corners)
        delta_dim = 12 if all_corners else 3
        self._delta_dim = delta_dim
        self.mlp = MlpModel(
            input_size=input_dim,
            hidden_sizes=hidden_sizes,
            output_size=2 * delta_dim + 4,  # mean and std for each delta dim, plus 4 categorical probs
        )

        self.delta_distribution = Gaussian(
            dim=delta_dim,
            squash=True,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.cat_distribution = Categorical(4)
Example #10
 def initialize(self, env_spaces, share_memory=False):
     env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)
     self.mu_model = self.MuModelCls(**env_model_kwargs,
                                     **self.mu_model_kwargs)
     self.q_model = self.QModelCls(**env_model_kwargs,
                                   **self.q_model_kwargs)
     if share_memory:
         self.mu_model.share_memory()
         # self.q_model.share_memory()  # Not needed for sampling.
         self.shared_mu_model = self.mu_model
         # self.shared_q_model = self.q_model
     if self.initial_mu_model_state_dict is not None:
         self.mu_model.load_state_dict(self.initial_mu_model_state_dict)
     if self.initial_q_model_state_dict is not None:
         self.q_model.load_state_dict(self.initial_q_model_state_dict)
     self.target_mu_model = self.MuModelCls(**env_model_kwargs,
                                            **self.mu_model_kwargs)
     self.target_mu_model.load_state_dict(self.mu_model.state_dict())
     self.target_q_model = self.QModelCls(**env_model_kwargs,
                                          **self.q_model_kwargs)
     self.target_q_model.load_state_dict(self.q_model.state_dict())
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         std=self.action_std,
         noise_clip=self.action_noise_clip,
         clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
     )
     self.env_spaces = env_spaces
     self.env_model_kwargs = env_model_kwargs
Example #11
 def initialize(self, env_spaces, share_memory=False):
     super().initialize(env_spaces, share_memory)
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         # min_std=MIN_STD,
         # clip=env_spaces.action.high[0],  # Probably +1?
     )
Example #12
 def initialize(self, env_spaces, share_memory=False,
         global_B=1, env_ranks=None):
     super().initialize(env_spaces, share_memory,
         global_B=global_B, env_ranks=env_ranks)
     assert len(env_spaces.action.shape) == 1
     # assert len(np.unique(env_spaces.action.high)) == 1
     # assert np.all(env_spaces.action.low == -env_spaces.action.high)
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         # min_std=MIN_STD,
         # clip=env_spaces.action.high[0],  # Probably +1?
     )
Example #13
 def initialize(self,
                env_spaces,
                share_memory=False,
                global_B=1,
                env_ranks=None):
     super().initialize(env_spaces, share_memory)
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         # min_std=MIN_STD,
         clip=env_spaces.action.high[0],  # Probably +1?
     )
Example #14
 def optim_initialize(self, rank=0):
     """Called by async runner."""
     self.rank = rank
     self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
         lr=self.learning_rate, **self.optim_kwargs)
     self.q_optimizers = [self.OptimCls(q_param)
                          for q_param in self.agent.q_parameters()]
     self.alpha_optimizer = self.OptimCls([self.agent.log_alpha],
         lr=self.learning_rate, **self.optim_kwargs)
     if self.initial_optim_state_dict is not None:
         self.pi_optimizer.load_state_dict(self.initial_optim_state_dict)
     if self.action_prior == "gaussian":
         self.action_prior_distribution = Gaussian(
             dim=self.agent.env_spaces.action.size, std=1.)
Example #15
 def initialize(self,
                env_spaces,
                share_memory=False,
                global_B=1,
                env_ranks=None):
     super().initialize(env_spaces, share_memory)
     assert len(env_spaces.action.shape) == 1
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         # min_std=MIN_STD,
         # clip=env_spaces.action.high[0],  # Probably +1?
     )
     self.distribution_omega = Categorical(
         dim=self.model_kwargs["option_size"])
Example #16
 def initialize(self, env_spaces, share_memory=False, global_B=1, env_ranks=None):
     super().initialize(env_spaces, share_memory, global_B, env_ranks)
     self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
     if self.initial_q2_model_state_dict is not None:
         self.q2_model.load_state_dict(self.initial_q2_model_state_dict)
     self.target_q2_model = self.QModelCls(
         **self.env_model_kwargs, **self.q_model_kwargs
     )
     self.target_q2_model.load_state_dict(self.q2_model.state_dict())
     self.target_distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         std=self.target_noise_std,
         noise_clip=self.target_noise_clip,
         clip=env_spaces.action.high[0],  # Assume symmetric low=-high.
     )
Example #17
 def initialize(self,
                env_spaces,
                share_memory=False,
                global_B=1,
                env_ranks=None):
     """Extends base method to build Gaussian distribution."""
     if not ((env_spaces.action.high == 1).all()
             and (env_spaces.action.low == -1).all()):
         raise ValueError("The space for all actions should be [-1, 1].")
     super().initialize(env_spaces,
                        share_memory,
                        global_B=global_B,
                        env_ranks=env_ranks)
     self.distribution = Gaussian(dim=env_spaces.action.shape[0],
                                  min_std=1e-6,
                                  max_std=1)
Example #18
    def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples):
        if agent.recurrent:
            raise NotImplementedError
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.optimizer = self.OptimCls(agent.parameters(),
                                       lr=self.learning_rate,
                                       **self.optim_kwargs)
        if self.initial_optim_state_dict is not None:
            self.optimizer.load_state_dict(self.initial_optim_state_dict)

        sample_bs = batch_spec.size
        train_bs = self.batch_size
        assert (self.training_ratio * sample_bs) % train_bs == 0
        self.updates_per_optimize = int(
            (self.training_ratio * sample_bs) // train_bs)
        logger.log(
            f"From sampler batch size {sample_bs}, training "
            f"batch size {train_bs}, and training ratio "
            f"{self.training_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sample_bs
        self.agent.give_min_itr_learn(self.min_itr_learn)

        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        self.replay_buffer = UniformReplayBuffer(**replay_kwargs)

        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=agent.env_spaces.action.size, std=1.)
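
The logged update count follows directly from the training ratio: with sample_bs transitions collected per sampler iteration and minibatches of train_bs, performing training_ratio * sample_bs / train_bs updates keeps the average number of times each collected sample is replayed equal to training_ratio. For example, with hypothetical sizes:

    sample_bs = 1000     # transitions per sampler iteration (hypothetical)
    train_bs = 250       # training minibatch size (hypothetical)
    training_ratio = 8   # desired trained-to-collected ratio
    assert (training_ratio * sample_bs) % train_bs == 0
    updates_per_optimize = (training_ratio * sample_bs) // train_bs  # -> 32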
Example #19
 def initialize(self,
                env_spaces,
                share_memory=False,
                global_B=1,
                env_ranks=None):
     """Extends base method to build Gaussian distribution."""
     super().initialize(env_spaces,
                        share_memory,
                        global_B=global_B,
                        env_ranks=env_ranks)
     assert len(env_spaces.action.shape) == 1
     assert len(np.unique(env_spaces.action.high)) == 1
     assert np.all(env_spaces.action.low == -env_spaces.action.high)
     self.distribution = Gaussian(
         dim=env_spaces.action.shape[0],
         # min_std=MIN_STD,
         # clip=env_spaces.action.high[0],  # Probably +1?
     )
     self.distribution_omega = Categorical(
         dim=self.model_kwargs["option_size"])
Example #20
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super(SacAgent, self).initialize(env_spaces,
                                         share_memory,
                                         global_B=global_B,
                                         env_ranks=env_ranks)

        self.target_model = self.ModelCls(**self.env_model_kwargs,
                                          **self.model_kwargs)
        self.target_model.load_state_dict(self.model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
Example #21
    def optim_initialize(self, rank=0):
        """Called in initialize or by async runner after forking sampler."""
        self.rank = rank

        # Be very explicit about which parameters are optimized where.
        self.pi_optimizer = self.OptimCls(
            chain(
                self.agent.pi_fc1.parameters(),  # No conv.
                self.agent.pi_mlp.parameters(),
            ),
            lr=self.pi_lr,
            betas=(self.pi_beta, 0.999),
        )
        self.q_optimizer = self.OptimCls(
            chain(
                () if self.stop_conv_grad else self.agent.conv.parameters(),
                self.agent.q_fc1.parameters(),
                self.agent.q_mlps.parameters(),
            ),
            lr=self.q_lr,
            betas=(self.q_beta, 0.999),
        )

        self._log_alpha = torch.tensor(np.log(self.alpha_init), requires_grad=True)
        self._alpha = torch.exp(self._log_alpha.detach())
        self.alpha_optimizer = self.OptimCls(
            (self._log_alpha,), lr=self.alpha_lr, betas=(self.alpha_beta, 0.999)
        )

        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=np.prod(self.agent.env_spaces.action.shape), std=1.0
            )
Example #22
    def optim_initialize(self, rank=0):
        """Called in initialize or by async runner after forking sampler."""
        self.rank = rank

        # Be very explicit about which parameters are optimized where.
        self.pi_optimizer = self.OptimCls(
            chain(
                self.agent.pi_fc1.parameters(),  # No conv.
                self.agent.pi_mlp.parameters(),
            ),
            lr=self.pi_lr,
            betas=(self.pi_beta, 0.999),
        )
        self.q_optimizer = self.OptimCls(
            chain(
                () if self.stop_rl_conv_grad else self.agent.conv.parameters(),
                self.agent.q_fc1.parameters(),
                self.agent.q_mlps.parameters(),
            ),
            lr=self.q_lr,
            betas=(self.q_beta, 0.999),
        )

        self._log_alpha = torch.tensor(np.log(self.alpha_init),
                                       requires_grad=True)
        self._alpha = torch.exp(self._log_alpha.detach())
        self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                             lr=self.alpha_lr,
                                             betas=(self.alpha_beta, 0.999))

        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(dim=np.prod(
                self.agent.env_spaces.action.shape),
                                                      std=1.0)

        self.ul_optimizer = self.OptimCls(self.ul_parameters(),
                                          lr=self.ul_learning_rate,
                                          **self.ul_optim_kwargs)

        self.total_ul_updates = sum([
            self.compute_ul_update_schedule(itr) for itr in range(self.n_itr)
        ])
        logger.log(
            f"Total number of UL updates to do: {self.total_ul_updates}.")
        self.ul_update_counter = 0
        self.ul_lr_scheduler = None
        if self.total_ul_updates > 0:
            if self.ul_lr_schedule == "linear":
                self.ul_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                    optimizer=self.ul_optimizer,
                    lr_lambda=lambda upd:
                    (self.total_ul_updates - upd) / self.total_ul_updates,
                )
            elif self.ul_lr_schedule == "cosine":
                self.ul_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer=self.ul_optimizer,
                    T_max=self.total_ul_updates - self.ul_lr_warmup,
                )
            elif self.ul_lr_schedule is not None:
                raise NotImplementedError

            if self.ul_lr_warmup > 0:
                self.ul_lr_scheduler = GradualWarmupScheduler(
                    self.ul_optimizer,
                    multiplier=1,
                    total_epoch=self.ul_lr_warmup,  # actually n_updates
                    after_scheduler=self.ul_lr_scheduler,
                )

            if self.ul_lr_scheduler is not None:
                self.ul_optimizer.zero_grad()
                self.ul_optimizer.step()

            self.c_e_loss = torch.nn.CrossEntropyLoss(
                ignore_index=IGNORE_INDEX)
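
With the linear schedule, LambdaLR scales the base UL learning rate by (total_ul_updates - upd) / total_ul_updates, decaying it to zero by the final update, and the initial zero_grad()/step() call is a common way to avoid PyTorch's warning about stepping a scheduler before its optimizer. A quick check of the linear multiplier with a hypothetical total of 1000 updates:

    total_ul_updates = 1000  # hypothetical
    lr_lambda = lambda upd: (total_ul_updates - upd) / total_ul_updates
    print(lr_lambda(0), lr_lambda(500), lr_lambda(1000))  # 1.0 0.5 0.0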
Example #23
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        self.conv = self.ConvModelCls(image_shape=env_spaces.observation.shape,
                                      **self.conv_kwargs)
        self.q_fc1 = self.Fc1ModelCls(input_size=self.conv.output_size,
                                      **self.fc1_kwargs)
        self.pi_fc1 = self.Fc1ModelCls(input_size=self.conv.output_size,
                                       **self.fc1_kwargs)

        latent_size = self.q_fc1.output_size
        action_size = env_spaces.action.shape[0]

        # These are just MLPs
        self.pi_mlp = self.PiModelCls(input_size=latent_size,
                                      action_size=action_size,
                                      **self.pi_model_kwargs)
        self.q_mlps = self.QModelCls(input_size=latent_size,
                                     action_size=action_size,
                                     **self.q_model_kwargs)
        self.target_q_mlps = copy.deepcopy(self.q_mlps)  # Separate params.

        # Make reference to the full actor model including encoder.
        # CAREFUL ABOUT TRAIN MODE FOR LAYER NORM IF CHANGING THIS?
        self.model = SacModel(conv=self.conv,
                              pi_fc1=self.pi_fc1,
                              pi_mlp=self.pi_mlp)

        if self.load_conv:
            logger.log("Agent loading state dict: " + self.state_dict_filename)
            loaded_state_dict = torch.load(self.state_dict_filename,
                                           map_location=torch.device("cpu"))
            # From UL, saves snapshot: params["algo_state_dict"]["encoder"]
            loaded_state_dict = loaded_state_dict.get("algo_state_dict",
                                                      loaded_state_dict)
            loaded_state_dict = loaded_state_dict.get("encoder",
                                                      loaded_state_dict)
            # A bit onerous, but ensures that state dicts match:
            conv_state_dict = OrderedDict([
                (k, v)  # .replace("conv.", "", 1)
                for k, v in loaded_state_dict.items() if k.startswith("conv.")
            ])
            self.conv.load_state_dict(conv_state_dict)
            # Double check it gets into the q_encoder as well.
            logger.log("Agent loaded CONV state dict.")
        elif self.load_all:
            # From RL, saves snapshot: params["agent_state_dict"]
            loaded_state_dict = torch.load(self.state_dict_filename,
                                           map_location=torch.device("cpu"))
            self.load_state_dict(loaded_state_dict["agent_state_dict"])
            logger.log("Agent loaded FULL state dict.")
        else:
            logger.log("Agent NOT loading state dict.")

        self.target_conv = copy.deepcopy(self.conv)
        self.target_q_fc1 = copy.deepcopy(self.q_fc1)

        if share_memory:
            # The actor model needs to share memory to sampler workers, and
            # this includes handling the encoder!
            # (Almost always just run serial anyway, no sharing.)
            self.model.share_memory()
            self.shared_model = self.model
        if self.initial_state_dict is not None:
            raise NotImplementedError
        self.env_spaces = env_spaces
        self.share_memory = share_memory

        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            squash=self.action_squash,
            # min_std=np.exp(MIN_LOG_STD),  # NOPE IN PI_MODEL NOW
            # max_std=np.exp(MAX_LOG_STD),
        )
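
When load_conv is set, the checkpoint is filtered down to entries whose keys start with "conv." before loading into the encoder; whether the prefix must also be stripped (the commented-out .replace("conv.", "", 1)) depends on how the checkpoint was saved. A small illustration of the filtering with dummy keys and values:

    from collections import OrderedDict

    loaded = OrderedDict([("conv.0.weight", 0), ("conv.0.bias", 1), ("head.weight", 2)])
    conv_sd = OrderedDict((k, v) for k, v in loaded.items() if k.startswith("conv."))
    print(list(conv_sd))  # ['conv.0.weight', 'conv.0.bias']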