Example no. 1
    def initialize(self, n_updates, cuda_idx=None):
        self.device = torch.device("cpu") if cuda_idx is None else torch.device(
            "cuda", index=cuda_idx)

        examples = self.load_replay()
        self.encoder = self.EncoderCls(
            image_shape=examples.observation.shape,
            latent_size=self.latent_size,  # UNUSED
            **self.encoder_kwargs
        )
        if self.onehot_action:
            act_dim = self.replay_buffer.samples.action.max() + 1  # discrete only
            self.distribution = Categorical(act_dim)
        else:
            act_shape = self.replay_buffer.samples.action.shape[2:]
            assert len(act_shape) == 1
            act_dim = act_shape[0]
        self.vae_head = self.VaeHeadCls(
            latent_size=self.latent_size,
            action_size=act_dim * self.delta_T,
            hidden_sizes=self.hidden_sizes,
        )
        self.decoder = self.DecoderCls(
            latent_size=self.latent_size,
            **self.decoder_kwargs
        )
        self.encoder.to(self.device)
        self.vae_head.to(self.device)
        self.decoder.to(self.device)

        self.optim_initialize(n_updates)

        if self.initial_state_dict is not None:
            self.load_state_dict(self.initial_state_dict)
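
A minimal, self-contained sketch of the two act_dim branches used above; the tensor shapes are made up for illustration and the replay buffer is replaced by plain tensors.

import torch

# Discrete / one-hot case: the number of actions is inferred from the largest
# stored action index, as in `action.max() + 1` above.
discrete_actions = torch.tensor([0, 3, 1, 2])        # stored action indices
act_dim_discrete = int(discrete_actions.max()) + 1   # -> 4

# Continuous case: actions are stored as [T, B, A]; the trailing shape gives
# the action dimension, as in `action.shape[2:]` above.
continuous_actions = torch.zeros(100, 8, 6)
act_shape = continuous_actions.shape[2:]
assert len(act_shape) == 1
act_dim_continuous = act_shape[0]                    # -> 6
print(act_dim_discrete, act_dim_continuous)
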
Example no. 2
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

        self.augs_funcs = OrderedDict()
        aug_to_func = {
            'crop': rad.random_crop,
            'crop_horiz': rad.random_crop_horizontile,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            'flip': rad.random_flip,
            'rotate': rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'no_aug': rad.no_aug,
        }

        if self.data_augs == "":
            aug_names = []
        else:
            aug_names = self.data_augs.split('-')
        for aug_name in aug_names:
            assert aug_name in aug_to_func, 'invalid data aug string'
            self.augs_funcs[aug_name] = aug_to_func[aug_name]
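
A self-contained sketch of how the hyphen-separated `data_augs` string becomes an ordered augmentation pipeline; the real `rad.*` functions are replaced here by identity stand-ins, which is an assumption for illustration only.

from collections import OrderedDict

def random_crop(x): return x       # stand-in for rad.random_crop
def random_grayscale(x): return x  # stand-in for rad.random_grayscale
def no_aug(x): return x            # stand-in for rad.no_aug

AUG_TO_FUNC = {"crop": random_crop, "grayscale": random_grayscale, "no_aug": no_aug}

def parse_data_augs(data_augs):
    """Split e.g. 'crop-grayscale' into an OrderedDict of name -> function,
    preserving the listed order (same parsing as the initialize() above)."""
    augs_funcs = OrderedDict()
    aug_names = [] if data_augs == "" else data_augs.split("-")
    for aug_name in aug_names:
        assert aug_name in AUG_TO_FUNC, "invalid data aug string"
        augs_funcs[aug_name] = AUG_TO_FUNC[aug_name]
    return augs_funcs

print(list(parse_data_augs("crop-grayscale")))  # ['crop', 'grayscale']
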
Example no. 3
    def initialize(self, n_updates, cuda_idx=None):
        self.device = torch.device(
            "cpu") if cuda_idx is None else torch.device("cuda",
                                                         index=cuda_idx)

        examples = self.load_replay()
        self.encoder = self.EncoderCls(
            image_shape=examples.observation.shape,
            latent_size=10,  # UNUSED
            **self.encoder_kwargs)

        if self.onehot_actions:
            act_dim = self.replay_buffer.samples.action.max() + 1
            self.distribution = Categorical(act_dim)
        else:
            assert len(self.replay_buffer.samples.action.shape) == 3
            act_dim = self.replay_buffer.samples.action.shape[2]
        self.inverse_model = self.InverseModelCls(
            input_size=self.encoder.conv_out_size,
            action_size=act_dim,
            num_actions=self.delta_T,
            use_input="conv",
            **self.inverse_model_kwargs)
        self.encoder.to(self.device)
        self.inverse_model.to(self.device)

        self.optim_initialize(n_updates)

        if self.initial_state_dict is not None:
            self.load_state_dict(self.initial_state_dict)
Example no. 4
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  # Don't let base agent try to load.
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q1_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)
        self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        # assert len(env_spaces.action.shape) == 1
        # self.distribution = Gaussian(
        #     dim=env_spaces.action.shape[0],
        #     squash=self.action_squash,
        #     min_std=np.exp(MIN_LOG_STD),
        #     max_std=np.exp(MAX_LOG_STD),
        # )
        self.distribution = Categorical(dim=env_spaces.action.n)
Example no. 5
class CategoricalPgAgent(BaseAgent):

    def __call__(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        pi, value = self.model(*model_inputs)
        return buffer_to((DistInfo(prob=pi), value), device="cpu")

    def initialize(self, env_spaces, share_memory=False,
            global_B=1, env_ranks=None):
        super().initialize(env_spaces, share_memory,
            global_B=global_B, env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        pi, value = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info, value=value)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        _pi, value = self.model(*model_inputs)
        return value.to("cpu")
Example no. 6
    def initialize(self, env_spaces, share_memory=False,
            global_B=1, env_ranks=None):
        super().initialize(env_spaces, share_memory,
            global_B=global_B, env_ranks=env_ranks)
        if self.override['override_policy_value']:
            policy_layers = self.override["policy_layers"]
            value_layers = self.override["value_layers"]
            self.model.override_policy_value(policy_layers=policy_layers,
                value_layers=value_layers)
        self.distribution = Categorical(dim=env_spaces.action.n)
Example no. 7
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)
Example no. 8
class RecurrentCategoricalPgAgentBase(BaseAgent):
    def __call__(self,
                 observation,
                 prev_action,
                 prev_reward,
                 init_rnn_state,
                 device="cpu"):
        # Assume init_rnn_state already shaped: [N,B,H]
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, init_rnn_state),
            device=self.device)
        pi, value, next_rnn_state = self.model(*model_inputs)
        dist_info, value = buffer_to((DistInfo(prob=pi), value), device=device)
        return dist_info, value, next_rnn_state  # Leave rnn_state on device.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward, device="cpu"):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
        agent_info = AgentInfoRnn(dist_info=dist_info,
                                  value=value,
                                  prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device=device)
        self.advance_rnn_state(rnn_state)  # Keep on device.
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward, device="cpu"):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, value, _rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        return value.to(device)
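
A small sketch of the rnn_state bookkeeping noted in step() above: the model keeps the state as [N, B, H] on device, and transpose(0, 1) is all that is needed to store it per-timestep as [B, N, H].

import torch

rnn_state = torch.zeros(1, 8, 256)   # [N=num_layers, B, H], as left by the model
stored = rnn_state.transpose(0, 1)   # [B, N, H], the layout used for the sample buffer
print(stored.shape)                  # torch.Size([8, 1, 256])
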
Example no. 9
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)
        self.distribution_omega = Categorical(
            dim=self.model_kwargs["option_size"])
Example no. 10
    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            all_corners=False
            ):
        super().__init__()
        self._obs_ndim = 1
        self._all_corners = all_corners
        input_dim = int(np.sum(observation_shape))

        print('all corners', self._all_corners)
        delta_dim = 12 if all_corners else 3
        self._delta_dim = delta_dim
        self.mlp = MlpModel(
            input_size=input_dim,
            hidden_sizes=hidden_sizes,
            output_size=2 * delta_dim + 4,  # mean and std per delta dim, plus 4 categorical probs
        )

        self.delta_distribution = Gaussian(
            dim=delta_dim,
            squash=True,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.cat_distribution = Categorical(4)
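
The forward pass is not shown here, but one plausible way to consume the MLP output of size 2 * delta_dim + 4 is to split it into delta means, delta log-stds, and the 4 categorical probabilities; the layout below is an assumption for illustration, not taken from the repo.

import torch

delta_dim = 3
out = torch.randn(5, 2 * delta_dim + 4)   # fake MLP output for a batch of 5
mu, log_std, cat_probs = torch.split(out, [delta_dim, delta_dim, 4], dim=-1)
print(mu.shape, log_std.shape, cat_probs.shape)   # [5, 3] [5, 3] [5, 4]
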
Example no. 11
    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            n_tile=20,
    ):
        super().__init__()
        self._obs_ndim = 1
        self._n_tile = n_tile
        input_dim = int(np.sum(observation_shape))

        self._action_size = action_size
        self.mlp_loc = MlpModel(
            input_size=input_dim,
            hidden_sizes=hidden_sizes,
            output_size=4
        )
        self.mlp_delta = MlpModel(
            input_size=input_dim + 4 * n_tile,
            hidden_sizes=hidden_sizes,
            output_size=3 * 2,
        )

        self.delta_distribution = Gaussian(
            dim=3,
            squash=True,
            min_std=np.exp(MIN_LOG_STD),
            max_std=np.exp(MAX_LOG_STD),
        )
        self.cat_distribution = Categorical(4)

        self._counter = 0
Example no. 12
class CategoricalPgAgent(BaseAgent):
    """
    Agent for policy gradient algorithm using categorical action distribution.
    Same as ``GaussianPgAgent`` and related classes, except uses
    ``Categorical`` distribution, and has a different interface to the model
    (model here outputs discrete probabilities rather than means and
    log_stds); these interfaces could perhaps be reorganized into a single
    policy gradient agent class.
    """
    def __call__(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value = self.model(*model_inputs)
        return buffer_to((DistInfo(prob=pi), value), device="cpu")

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info, value=value)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, value = self.model(*model_inputs)
        return value.to("cpu")
Example no. 13
    def initialize(self, env_spaces, share_memory=False, global_B=1, env_ranks=None):
        self.model = self.ModelCls(
            image_shape=env_spaces.observation.shape,
            output_size=env_spaces.action.n,
            **self.model_kwargs
        )  # Model will have stop_grad inside it.
        if self.load_conv:
            logger.log("Agent loading state dict: " + self.state_dict_filename)
            loaded_state_dict = torch.load(
                self.state_dict_filename, map_location=torch.device("cpu")
            )
            # From UL, saves snapshot: params["algo_state_dict"]["encoder"]
            loaded_state_dict = loaded_state_dict.get(
                "algo_state_dict", loaded_state_dict
            )
            loaded_state_dict = loaded_state_dict.get("encoder", loaded_state_dict)
            # A bit onerous, but ensures that state dicts match:
            conv_state_dict = OrderedDict(
                [
                    (k.replace("conv.", "", 1), v)
                    for k, v in loaded_state_dict.items()
                    if k.startswith("conv.")
                ]
            )
            self.model.conv.load_state_dict(conv_state_dict)
            logger.log("Agent loaded CONV state dict.")
        elif self.load_all:
            # From RL, saves snapshot: params["agent_state_dict"]
            loaded_state_dict = torch.load(
                self.state_dict_filename, map_location=torch.device("cpu")
            )
            self.load_state_dict(loaded_state_dict["agent_state_dict"])
            logger.log("Agnet loaded FULL state dict.")
        else:
            logger.log("Agent NOT loading state dict.")

        if share_memory:
            self.model.share_memory()
            self.shared_model = self.model
        if self.initial_model_state_dict is not None:
            raise NotImplementedError
        self.distribution = Categorical(dim=env_spaces.action.n)
        self.env_spaces = env_spaces
        self.share_memory = share_memory
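
A self-contained sketch of the prefix-stripping load above: a checkpoint saved with keys like "conv.0.weight" is filtered and renamed so it can be loaded into the bare conv sub-module. The module shapes here are illustrative only.

import torch
from collections import OrderedDict

conv = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
full_state = {("conv." + k): v for k, v in conv.state_dict().items()}
full_state["head.weight"] = torch.zeros(1)   # unrelated key, filtered out below

conv_state = OrderedDict(
    (k.replace("conv.", "", 1), v)
    for k, v in full_state.items()
    if k.startswith("conv.")
)
conv.load_state_dict(conv_state)   # loads cleanly; "head.weight" never reaches it
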
Example no. 14
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  # Don't let base agent try to load.
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict

        self.critic = self.CriticCls(**self.env_model_kwargs,
                                     EncoderCls=self.EncoderCls,
                                     encoder_kwargs=self.encoder_kwargs,
                                     **self.critic_kwargs)
        self.target_model = self.CriticCls(**self.env_model_kwargs,
                                           EncoderCls=self.EncoderCls,
                                           encoder_kwargs=self.encoder_kwargs,
                                           **self.critic_kwargs)
        self.decoder = self.DecoderCls(**self.encoder_kwargs)
        # self.q1_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        # self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        # self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
        #     **self.q_model_kwargs)
        # self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
        #     **self.q_model_kwargs)
        self.target_model.load_state_dict(self.critic.state_dict())
        # Tie the Encoder of the actor to that of the critic
        self.model.encoder.copy_weights_from(self.critic.encoder)

        # self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        # self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        # assert len(env_spaces.action.shape) == 1
        # self.distribution = Gaussian(
        #     dim=env_spaces.action.shape[0],
        #     squash=self.action_squash,
        #     min_std=np.exp(MIN_LOG_STD),
        #     max_std=np.exp(MAX_LOG_STD),
        # )
        self.distribution = Categorical(dim=env_spaces.action.n)
Example no. 15
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces, share_memory)
        assert len(env_spaces.action.shape) == 1
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            # min_std=MIN_STD,
            # clip=env_spaces.action.high[0],  # Probably +1?
        )
        self.distribution_omega = Categorical(
            dim=self.model_kwargs["option_size"])
Example no. 16
    def initialize(self, env_spaces, share_memory=False,
            global_B=1, env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  # Don't let base agent try to load.
        super().initialize(env_spaces, share_memory,
            global_B=global_B, env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q1_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
            **self.q_model_kwargs)
        self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
            **self.q_model_kwargs)
        self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        self.distribution = Categorical(dim=env_spaces.action.n)
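
A hedged sketch of the target-network pattern used here and in the SAC agents further below: a hard copy at initialization (tau = 1), and a soft/polyak update target <- tau * source + (1 - tau) * target during training. The tiny Linear modules stand in for the Q models.

import torch

q_model = torch.nn.Linear(4, 2)
target_q_model = torch.nn.Linear(4, 2)
target_q_model.load_state_dict(q_model.state_dict())   # hard copy at init

def soft_update(target, source, tau):
    target_state = target.state_dict()
    new_state = {k: tau * v + (1 - tau) * target_state[k]
                 for k, v in source.state_dict().items()}
    target.load_state_dict(new_state)

soft_update(target_q_model, q_model, tau=0.005)
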
Example no. 17
    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        """Extends base method to build Gaussian distribution."""
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        assert len(env_spaces.action.shape) == 1
        assert len(np.unique(env_spaces.action.high)) == 1
        assert np.all(env_spaces.action.low == -env_spaces.action.high)
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            # min_std=MIN_STD,
            # clip=env_spaces.action.high[0],  # Probably +1?
        )
        self.distribution_omega = Categorical(
            dim=self.model_kwargs["option_size"])
Example no. 18
class DmlabPgBaseAgent(BaseAgent):
    """Only doing the feedforward agent for now."""

    def __init__(
        self,
        ModelCls=DmlabPgLstmModel,
        store_latent=False,
        state_dict_filename=None,
        load_conv=False,
        load_all=False,
        **kwargs
    ):
        super().__init__(ModelCls=ModelCls, **kwargs)
        self.store_latent = store_latent
        self.state_dict_filename = state_dict_filename
        self.load_conv = load_conv
        self.load_all = load_all
        assert not (load_all and load_conv)
        self._act_uniform = False

    def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, init_rnn_state), device=self.device
        )
        pi, value, next_rnn_state, _ = self.model(*model_inputs)  # Ignore conv out
        dist_info, value = buffer_to((DistInfo(prob=pi), value), device="cpu")
        return dist_info, value, next_rnn_state

    def initialize(self, env_spaces, share_memory=False, global_B=1, env_ranks=None):
        self.model = self.ModelCls(
            image_shape=env_spaces.observation.shape,
            output_size=env_spaces.action.n,
            **self.model_kwargs
        )  # Model will have stop_grad inside it.
        if self.load_conv:
            logger.log("Agent loading state dict: " + self.state_dict_filename)
            loaded_state_dict = torch.load(
                self.state_dict_filename, map_location=torch.device("cpu")
            )
            # From UL, saves snapshot: params["algo_state_dict"]["encoder"]
            loaded_state_dict = loaded_state_dict.get(
                "algo_state_dict", loaded_state_dict
            )
            loaded_state_dict = loaded_state_dict.get("encoder", loaded_state_dict)
            # A bit onerous, but ensures that state dicts match:
            conv_state_dict = OrderedDict(
                [
                    (k.replace("conv.", "", 1), v)
                    for k, v in loaded_state_dict.items()
                    if k.startswith("conv.")
                ]
            )
            self.model.conv.load_state_dict(conv_state_dict)
            logger.log("Agent loaded CONV state dict.")
        elif self.load_all:
            # From RL, saves snapshot: params["agent_state_dict"]
            loaded_state_dict = torch.load(
                self.state_dict_filename, map_location=torch.device("cpu")
            )
            self.load_state_dict(loaded_state_dict["agent_state_dict"])
            logger.log("Agnet loaded FULL state dict.")
        else:
            logger.log("Agent NOT loading state dict.")

        if share_memory:
            self.model.share_memory()
            self.shared_model = self.model
        if self.initial_model_state_dict is not None:
            raise NotImplementedError
        self.distribution = Categorical(dim=env_spaces.action.n)
        self.env_spaces = env_spaces
        self.share_memory = share_memory

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward), device=self.device
        )
        pi, value, rnn_state, conv = self.model(*model_inputs, self.prev_rnn_state)
        if self._act_uniform:
            pi[:] = 1.0 / pi.shape[-1]  # uniform
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
        agent_info = AgentInfoRnnConv(
            dist_info=dist_info,
            value=value,
            prev_rnn_state=prev_rnn_state,
            conv=conv if self.store_latent else None,
        )  # Don't write the extra data.
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to(
            (observation, prev_action, prev_reward), device=self.device
        )
        _pi, value, _rnn_state, _conv = self.model(*agent_inputs, self.prev_rnn_state)
        return value.to("cpu")

    def set_act_uniform(self, act_uniform=True):
        self._act_uniform = act_uniform
Example no. 19
class DiscreteSacAEAgent(BaseAgent):
    """Agent for SAC algorithm, including action-squashing, using twin Q-values."""
    def __init__(
        self,
        ModelCls=SACAEActor,  # Pi model.
        CriticCls=SACAECritic,
        EncoderCls=PixelEncoder,
        DecoderCls=PixelDecoder,
        model_kwargs={},  # Pi model.
        critic_kwargs={},
        encoder_kwargs={},
        initial_model_state_dict=None,  # All models.
        pretrain_std=0.75,  # With squash 0.75 is near uniform.
        random_actions_for_pretraining=False):
        model_kwargs["EncoderCls"] = EncoderCls
        model_kwargs["encoder_kwargs"] = encoder_kwargs
        """Saves input arguments; network defaults stored within."""
        super().__init__(ModelCls=ModelCls,
                         model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict)
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  # Don't let base agent try to load.
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict

        self.critic = self.CriticCls(**self.env_model_kwargs,
                                     EncoderCls=self.EncoderCls,
                                     encoder_kwargs=self.encoder_kwargs,
                                     **self.critic_kwargs)
        self.target_model = self.CriticCls(**self.env_model_kwargs,
                                           EncoderCls=self.EncoderCls,
                                           encoder_kwargs=self.encoder_kwargs,
                                           **self.critic_kwargs)
        self.decoder = self.DecoderCls(**self.encoder_kwargs)
        # self.q1_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        # self.q2_model = self.QModelCls(**self.env_model_kwargs, **self.q_model_kwargs)
        # self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
        #     **self.q_model_kwargs)
        # self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
        #     **self.q_model_kwargs)
        self.target_model.load_state_dict(self.critic.state_dict())
        # Tie the Encoder of the actor to that of the critic
        self.model.encoder.copy_weights_from(self.critic.encoder)

        # self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        # self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        # assert len(env_spaces.action.shape) == 1
        # self.distribution = Gaussian(
        #     dim=env_spaces.action.shape[0],
        #     squash=self.action_squash,
        #     min_std=np.exp(MIN_LOG_STD),
        #     max_std=np.exp(MAX_LOG_STD),
        # )
        self.distribution = Categorical(dim=env_spaces.action.n)

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        # self.q1_model.to(self.device)
        # self.q2_model.to(self.device)
        # self.target_q1_model.to(self.device)
        # self.target_q2_model.to(self.device)
        self.critic.to(self.device)
        self.target_model.to(self.device)
        self.decoder.to(self.device)

    def data_parallel(self):
        super().data_parallel()
        DDP_WRAP = DDPC if self.device.type == "cpu" else DDP
        # self.q1_model = DDP_WRAP(self.q1_model)
        # self.q2_model = DDP_WRAP(self.q2_model)
        self.critic = DDP_WRAP(self.critic)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.n,
        )

    def q(self, observation, prev_action, prev_reward, detach_encoder=False):
        """Compute twin Q-values for state/observation and input action 
        (with grad)."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # q1 = self.q1_model(*model_inputs)
        # q2 = self.q2_model(*model_inputs)
        q1, q2, _, _ = self.critic(*model_inputs,
                                   detach_encoder=detach_encoder)
        # print("critic device", self.critic.q1[0].weight.device)
        return q1.cpu(), q2.cpu()

    def target_q(self,
                 observation,
                 prev_action,
                 prev_reward,
                 detach_encoder=False):
        """Compute twin target Q-values for state/observation and input
        action."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # target_q1 =self.target_q1_model(*model_inputs)
        # target_q2 = self.target_q2_model(*model_inputs)
        target_q1, target_q2, _, _ = self.target_model(
            *model_inputs, detach_encoder=detach_encoder)
        return target_q1.cpu(), target_q2.cpu()

    def pi(self, observation, prev_action, prev_reward, detach_encoder=False):
        """Compute action log-probabilities for state/observation, and
        sample new action (with grad).  Uses special ``sample_loglikelihood()``
        method of Gaussian distriution, which handles action squashing
        through this process."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # mean, log_std = self.model(*model_inputs)
        # dist_info = DistInfoStd(mean=mean, log_std=log_std)
        # action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        pi, _, _ = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        log_pi = torch.log(pi)

        # TODO: potentially use argmax to determine the action instead of sampling from the distribution.
        # TODO: Figure out what to do for log_pi

        # action = self.distribution.sample(dist_info)
        # log_pi = self.distribution.log_likelihood(action, dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    def z(self, observation, prev_action, prev_reward, detach_encoder=False):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _, _, mu, log_sd = self.critic(*model_inputs,
                                       detach_encoder=detach_encoder)
        if not self.critic.rae:
            # Reparameterize
            sd = torch.exp(log_sd)
            eps = torch.randn_like(sd)
            z = mu + eps * sd
        else:
            z = mu
        return z, mu, log_sd

    def decode(self, z):
        z, = buffer_to((z, ), device=self.device)
        return self.decoder(z)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # mean, log_std = self.model(*model_inputs)
        # dist_info = DistInfoStd(mean=mean, log_std=log_std)
        # action = self.distribution.sample(dist_info)
        if self.random_actions_for_pretraining:
            action = torch.randint_like(prev_action, 15)
            action = buffer_to(action, device="cpu")
            return AgentStep(action=action,
                             agent_info=AgentInfo(dist_info=None))

        pi, _, _ = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_model, self.critic.state_dict(), tau)

    def rae(self):
        return self.critic.rae

    @property
    def models(self):
        return Models(model=self.model, critic=self.critic)

    def actor_parameters(self):
        return self.model.parameters()

    def critic_parameters(self):
        return self.critic.parameters()

    def decoder_parameters(self):
        return self.decoder.parameters()

    def encoder_parameters(self):
        return self.critic.encoder.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.critic.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.critic.eval()
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        # std = None if itr >= self.min_itr_learn else self.pretrain_std
        # self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.critic.eval()
        # self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),  # Pi model.
            critic=self.critic.state_dict(),
            target_model=self.target_model.state_dict(),
            decoder=self.decoder.state_dict())

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.critic.load_state_dict(state_dict["critic"])
        self.target_model.load_state_dict(state_dict["target_model"])
        self.decoder.load_state_dict(state_dict["decoder"])
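
The encoder tying above relies on the repo's own ``copy_weights_from()``; a generic stand-in is copying the critic encoder's parameters into the actor encoder so both start identical (whether the repo additionally shares the underlying tensors afterwards is not shown here, so this is only an approximation).

import torch

actor_encoder = torch.nn.Linear(64, 32)
critic_encoder = torch.nn.Linear(64, 32)
actor_encoder.load_state_dict(critic_encoder.state_dict())   # start from identical weights
assert all(torch.equal(a, c) for a, c in
           zip(actor_encoder.parameters(), critic_encoder.parameters()))
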
Example no. 20
class DiscreteSacAgent(BaseAgent):
    """Agent for SAC algorithm, including action-squashing, using twin Q-values."""
    def __init__(
            self,
            ModelCls=DiscreteMlp,  # Pi model.
            QModelCls=DiscreteQModel,
            model_kwargs=None,  # Pi model.
            q_model_kwargs=None,
            v_model_kwargs=None,
            initial_model_state_dict=None,  # All models.
            pretrain_std=0.75,  # With squash 0.75 is near uniform.
    ):
        """Saves input arguments; network defaults stored within."""
        if model_kwargs is None:
            model_kwargs = dict(hidden_sizes=[256, 256])
        if q_model_kwargs is None:
            q_model_kwargs = dict(hidden_sizes=[256, 256])
        if v_model_kwargs is None:
            v_model_kwargs = dict(hidden_sizes=[256, 256])
        super().__init__(ModelCls=ModelCls,
                         model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict)
        save__init__args(locals())
        self.min_itr_learn = 0  # Get from algo.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        _initial_model_state_dict = self.initial_model_state_dict
        self.initial_model_state_dict = None  # Don't let base agent try to load.
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.initial_model_state_dict = _initial_model_state_dict
        self.q1_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)
        self.q2_model = self.QModelCls(**self.env_model_kwargs,
                                       **self.q_model_kwargs)
        self.target_q1_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q2_model = self.QModelCls(**self.env_model_kwargs,
                                              **self.q_model_kwargs)
        self.target_q1_model.load_state_dict(self.q1_model.state_dict())
        self.target_q2_model.load_state_dict(self.q2_model.state_dict())
        if self.initial_model_state_dict is not None:
            self.load_state_dict(self.initial_model_state_dict)
        # assert len(env_spaces.action.shape) == 1
        # self.distribution = Gaussian(
        #     dim=env_spaces.action.shape[0],
        #     squash=self.action_squash,
        #     min_std=np.exp(MIN_LOG_STD),
        #     max_std=np.exp(MAX_LOG_STD),
        # )
        self.distribution = Categorical(dim=env_spaces.action.n)

    def to_device(self, cuda_idx=None):
        super().to_device(cuda_idx)
        self.q1_model.to(self.device)
        self.q2_model.to(self.device)
        self.target_q1_model.to(self.device)
        self.target_q2_model.to(self.device)

    def data_parallel(self):
        super().data_parallel()
        DDP_WRAP = DDPC if self.device.type == "cpu" else DDP
        self.q1_model = DDP_WRAP(self.q1_model)
        self.q2_model = DDP_WRAP(self.q2_model)

    def give_min_itr_learn(self, min_itr_learn):
        self.min_itr_learn = min_itr_learn  # From algo.

    def make_env_to_model_kwargs(self, env_spaces):
        return dict(
            observation_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.n,
        )

    def q(self, observation, prev_action, prev_reward):
        """Compute twin Q-values for state/observation and input action 
        (with grad)."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        q1 = self.q1_model(*model_inputs)
        q2 = self.q2_model(*model_inputs)
        return q1.cpu(), q2.cpu()

    def target_q(self, observation, prev_action, prev_reward):
        """Compute twin target Q-values for state/observation and input
        action."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        target_q1 = self.target_q1_model(*model_inputs)
        target_q2 = self.target_q2_model(*model_inputs)
        return target_q1.cpu(), target_q2.cpu()

    def pi(self, observation, prev_action, prev_reward):
        """Compute action log-probabilities for state/observation, and
        sample new action (with grad).  Uses special ``sample_loglikelihood()``
        method of Gaussian distriution, which handles action squashing
        through this process."""
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # mean, log_std = self.model(*model_inputs)
        # dist_info = DistInfoStd(mean=mean, log_std=log_std)
        # action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        pi = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        log_pi = torch.log(pi)

        # TODO: potentially use argmax to determine the action instead of sampling from the distribution.
        # TODO: Figure out what to do for log_pi

        # action = self.distribution.sample(dist_info)
        # log_pi = self.distribution.log_likelihood(action, dist_info)
        log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
        return action, log_pi, dist_info  # Action stays on device for q models.

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # mean, log_std = self.model(*model_inputs)
        # dist_info = DistInfoStd(mean=mean, log_std=log_std)
        # action = self.distribution.sample(dist_info)
        pi = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    def update_target(self, tau=1):
        update_state_dict(self.target_q1_model, self.q1_model.state_dict(),
                          tau)
        update_state_dict(self.target_q2_model, self.q2_model.state_dict(),
                          tau)

    @property
    def models(self):
        return Models(pi=self.model, q1=self.q1_model, q2=self.q2_model)

    def pi_parameters(self):
        return self.model.parameters()

    def q1_parameters(self):
        return self.q1_model.parameters()

    def q2_parameters(self):
        return self.q2_model.parameters()

    def train_mode(self, itr):
        super().train_mode(itr)
        self.q1_model.train()
        self.q2_model.train()

    def sample_mode(self, itr):
        super().sample_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        if itr == 0:
            logger.log(f"Agent at itr {itr}, sample std: {self.pretrain_std}")
        if itr == self.min_itr_learn:
            logger.log(f"Agent at itr {itr}, sample std: learned.")
        # std = None if itr >= self.min_itr_learn else self.pretrain_std
        # self.distribution.set_std(std)  # If None: std from policy dist_info.

    def eval_mode(self, itr):
        super().eval_mode(itr)
        self.q1_model.eval()
        self.q2_model.eval()
        # self.distribution.set_std(0.)  # Deterministic (dist_info std ignored).

    def state_dict(self):
        return dict(
            model=self.model.state_dict(),  # Pi model.
            q1_model=self.q1_model.state_dict(),
            q2_model=self.q2_model.state_dict(),
            target_q1_model=self.target_q1_model.state_dict(),
            target_q2_model=self.target_q2_model.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.q1_model.load_state_dict(state_dict["q1_model"])
        self.q2_model.load_state_dict(state_dict["q2_model"])
        self.target_q1_model.load_state_dict(state_dict["target_q1_model"])
        self.target_q2_model.load_state_dict(state_dict["target_q2_model"])
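
The TODOs in pi() leave the discrete log_pi computation open; one common option (an assumption here, not necessarily what this repo settles on) is to take the log of the full probability vector with a small offset wherever a probability is exactly zero, so that log_pi never becomes -inf.

import torch

def discrete_log_pi(pi, eps=1e-8):
    """pi: [B, n_actions] action probabilities from the policy model."""
    zero_mask = (pi == 0.0).float() * eps   # add eps only where prob is exactly 0
    return torch.log(pi + zero_mask)

pi = torch.tensor([[0.7, 0.3, 0.0]])
print(discrete_log_pi(pi))   # finite everywhere, including the zero-probability action
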
Example no. 21
class RADPgVaeAgent(BaseAgent):
    """
    Agent for policy gradient algorithm using categorical action distribution.
    Same as ``GaussianPgAgent`` and related classes, except uses
    ``Categorical`` distribution, and has a different interface to the model
    (model here outputs discrete probabilities in place of means and log_stds,
    while both output the value estimate).
    """
    def __init__(self,
                 ModelCls=None,
                 model_kwargs=None,
                 initial_model_state_dict=None,
                 data_augs="",
                 vae_loss_type="l2",
                 vae_beta=1.0,
                 sim_loss_coef=0.1,
                 k_dim=24):
        super().__init__(ModelCls=ModelCls,
                         model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict)
        self.data_augs = data_augs
        self.vae_loss_type = vae_loss_type
        self.vae_beta = vae_beta
        self.sim_loss_coef = sim_loss_coef
        self.k_dim = k_dim

    def aug_obs(self, observation):
        # Apply initial augmentations
        for aug, func in self.augs_funcs.items():
            if 'cutout' in aug:
                observation = func(observation)

        observation = observation.type(torch.float)  # Expect torch.uint8 inputs.
        observation = observation.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        for aug, func in self.augs_funcs.items():
            if 'cutout' in aug:
                continue
            else:
                observation = func(observation)
        return observation

    def __call__(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        # This is what needs to be modified to apply the augmentation to the data.
        assert len(observation.shape) == 4, "Observation shape was not length 4"
        observation_one = self.aug_obs(observation)
        observation_two = self.aug_obs(observation.detach().clone())

        # if hasattr(self.model, "final_act") and self.model.final_act == "tanh":
        #     observation_one = 2*observation_one - 1
        #     observation_two = 2*observation_two - 1

        observation_one, observation_two, prev_action, prev_reward = buffer_to(
            (observation_one, observation_two, prev_action, prev_reward),
            device=self.device)

        pi_one, value_one, latent_one, reconstruction_one = self.model(
            observation_one, prev_action, prev_reward)
        pi_two, value_two, latent_two, reconstruction_two = self.model(
            observation_two, prev_action, prev_reward)
        bs = 2 * len(observation)

        if self.vae_loss_type == "l2":
            recon_loss = (torch.sum(
                (observation_one - reconstruction_one).pow(2)) + torch.sum(
                    (observation_two - reconstruction_two).pow(2))) / bs
        # elif self.vae_loss_type == "bce":
        #     recon_loss = (torch.nn.functional.binary_cross_entropy(reconstruction, obs) +
        #                 torch.nn.functional.binary_cross_entropy(reconstruction, obs)) / 2

        # Calculate the similarity loss
        mu_one, logsd_one = torch.chunk(latent_one, 2, dim=1)
        mu_two, logsd_two = torch.chunk(latent_two, 2, dim=1)

        latent_loss_one = torch.sum(-0.5 *
                                    (1 + (2 * logsd_one) - mu_one.pow(2) -
                                     (2 * logsd_one).exp()))
        latent_loss_two = torch.sum(-0.5 *
                                    (1 + (2 * logsd_two) - mu_two.pow(2) -
                                     (2 * logsd_two).exp()))
        latent_loss = (latent_loss_one + latent_loss_two) / bs

        mu_one, mu_two = mu_one[:, :self.k_dim], mu_two[:, :self.k_dim]
        logvar_one = 2 * logsd_one[:, :self.k_dim]
        logvar_two = 2 * logsd_two[:, :self.k_dim]
        # KL divergence between original and augmented.
        sim_loss = torch.sum(logvar_two - logvar_one + 0.5 *
                             (logvar_one.exp() +
                              (mu_one - mu_two).pow(2)) / logvar_two.exp() -
                             0.5) / (bs // 2)

        vae_loss = recon_loss + self.vae_beta * latent_loss + self.sim_loss_coef * sim_loss

        pi_avg = (pi_one + pi_two) / 2

        return buffer_to(
            (DistInfo(prob=pi_avg), value_one, value_two, vae_loss),
            device="cpu")

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

        self.augs_funcs = OrderedDict()
        aug_to_func = {
            'crop': rad.random_crop,
            'crop_horiz': rad.random_crop_horizontile,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            'flip': rad.random_flip,
            'rotate': rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'no_aug': rad.no_aug,
        }

        if self.data_augs == "":
            aug_names = []
        else:
            aug_names = self.data_augs.split('-')
        for aug_name in aug_names:
            assert aug_name in aug_to_func, 'invalid data aug string'
            self.augs_funcs[aug_name] = aug_to_func[aug_name]

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        #observation = observation.type(torch.float)  # Expect torch.uint8 inputs
        #observation = observation.mul_(1. / 255)  # From [0-255] to [0-1], in place.
        if len(observation.shape) == 3:
            observation = self.aug_obs(observation.unsqueeze(0)).squeeze(0)
        else:
            observation = self.aug_obs(observation)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value, latent, reconstruction = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info, value=value)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        #observation = observation.type(torch.float)  # Expect torch.uint8 inputs
        #observation = observation.mul_(1. / 255)  # From [0-255] to [0-1], in place.
        if len(observation.shape) == 3:
            observation = self.aug_obs(observation.unsqueeze(0)).squeeze(0)
        else:
            observation = self.aug_obs(observation)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, value, _latent, _reconstruction = self.model(*model_inputs)
        return value.to("cpu")

    @torch.no_grad()
    def reconstructions(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        observation = observation.type(torch.float)
        observation = observation.mul_(1. / 255)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, _value, _latent, reconstruction = self.model(*model_inputs)
        return reconstruction.to("cpu")
Example no. 22
class RADPgAgent(BaseAgent):
    """
    Agent for policy gradient algorithm using categorical action distribution.
    Same as ``GaussianPgAgent`` and related classes, except uses
    ``Categorical`` distribution, and has a different interface to the model
    (model here outputs discrete probabilities in place of means and log_stds,
    while both output the value estimate).
    """
    def __init__(self,
                 ModelCls=None,
                 model_kwargs=None,
                 initial_model_state_dict=None,
                 data_augs="",
                 both_actions=False):
        super().__init__(ModelCls=ModelCls,
                         model_kwargs=model_kwargs,
                         initial_model_state_dict=initial_model_state_dict)
        self.data_augs = data_augs
        self.both_actions = both_actions

    def aug_obs(self, observation):
        # Apply initial augmentations
        for aug, func in self.augs_funcs.items():
            if 'cutout' in aug:
                observation = func(observation)

        observation = observation.type(torch.float)  # Expect torch.uint8 inputs.
        observation = observation.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        for aug, func in self.augs_funcs.items():
            if 'cutout' in aug:
                continue
            else:
                observation = func(observation)
        return observation

    def __call__(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)

        # This is what needs to be modified to apply the augmentation to the data.
        assert len(observation.shape) == 4, "Observation shape was not length 4"
        augmented = self.aug_obs(observation)

        observation = observation.type(torch.float)  # Expect torch.uint8 inputs.
        observation = observation.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        augmented, observation, prev_action, prev_reward = buffer_to(
            (augmented, observation, prev_action, prev_reward),
            device=self.device)

        # For visualizing the observations
        # import matplotlib.pyplot as plt
        # from torchvision.utils import make_grid
        # def show_imgs(x,max_display=16):
        #     grid = make_grid(x[:max_display],4).permute(1,2,0).cpu().numpy()
        #     plt.xticks([])
        #     plt.yticks([])
        #     plt.imshow(grid)
        #     plt.show()
        # show_imgs(orig_observation)
        # show_imgs(observation)
        aug_pi, aug_value = self.model(augmented, prev_action, prev_reward)
        if self.both_actions:
            pi, value = self.model(observation, prev_action, prev_reward)
            return buffer_to(
                (DistInfo(prob=aug_pi), DistInfo(prob=pi), aug_value, value),
                device="cpu")
        return buffer_to((DistInfo(prob=aug_pi), aug_value), device="cpu")

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

        self.augs_funcs = OrderedDict()
        aug_to_func = {
            'crop': rad.random_crop,
            'crop_horiz': rad.random_crop_horizontile,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            'flip': rad.random_flip,
            'rotate': rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'no_aug': rad.no_aug,
        }

        if self.data_augs == "":
            aug_names = []
        else:
            aug_names = self.data_augs.split('-')
        for aug_name in aug_names:
            assert aug_name in aug_to_func, 'invalid data aug string'
            self.augs_funcs[aug_name] = aug_to_func[aug_name]

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        observation = observation.type(
            torch.float)  # Expect torch.uint8 inputs
        observation = observation.mul_(1. /
                                       255)  # From [0-255] to [0-1], in place.
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info, value=value)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        observation = observation.type(
            torch.float)  # Expect torch.uint8 inputs
        observation = observation.mul_(1. /
                                       255)  # From [0-255] to [0-1], in place.
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, value = self.model(*model_inputs)
        return value.to("cpu")
Example no. 23
0
class AtariPgAgent(BaseAgent):
    """Only doing the feedforward agent for now."""

    def __init__(
            self,
            ModelCls=AtariPgModel,
            store_latent=False,
            state_dict_filename=None,
            load_conv=False,
            load_all=False,
            **kwargs
        ):
        super().__init__(ModelCls=ModelCls, **kwargs)
        self.store_latent = store_latent
        self.state_dict_filename = state_dict_filename
        self.load_conv = load_conv
        self.load_all = load_all
        assert not (load_all and load_conv)
        self._act_uniform = False

    def __call__(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        pi, value, _ = self.model(*model_inputs)  # ignore conv output
        return buffer_to((DistInfo(prob=pi), value), device="cpu")

    def initialize(self, env_spaces, share_memory=False,
            global_B=1, env_ranks=None):
        self.model = self.ModelCls(
            image_shape=env_spaces.observation.shape,
            action_size=env_spaces.action.n,
            **self.model_kwargs
        )  # Model will have stop_grad inside it.
        if self.load_conv:
            logger.log("Agent loading state dict: " + self.state_dict_filename)
            loaded_state_dict = torch.load(self.state_dict_filename,
                map_location=torch.device('cpu'))
            # From UL, saves snapshot: params["algo_state_dict"]["encoder"]
            loaded_state_dict = loaded_state_dict.get("algo_state_dict", loaded_state_dict)
            loaded_state_dict = loaded_state_dict.get("encoder", loaded_state_dict)
            # A bit onerous, but ensures that state dicts match:
            conv_state_dict = OrderedDict([(k.replace("conv.", "", 1), v)
                for k, v in loaded_state_dict.items() if k.startswith("conv.")])
            self.model.conv.load_state_dict(conv_state_dict)
            logger.log("Agent loaded CONV state dict.")
        elif self.load_all:
            # From RL, saves snapshot: params["agent_state_dict"]
            loaded_state_dict = torch.load(self.state_dict_filename,
                map_location=torch.device('cpu'))
            self.load_state_dict(loaded_state_dict["agent_state_dict"])
            logger.log("Agnet loaded FULL state dict.")
        else:
            logger.log("Agent NOT loading state dict.")

        if share_memory:
            self.model.share_memory()
            self.shared_model = self.model
        if self.initial_model_state_dict is not None:
            raise NotImplementedError
        self.distribution = Categorical(dim=env_spaces.action.n)
        self.env_spaces = env_spaces
        self.share_memory = share_memory

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        pi, value, conv = self.model(*model_inputs)
        if self._act_uniform:
            pi[:] = 1. / pi.shape[-1]  # uniform
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfoConv(dist_info=dist_info, value=value,
            conv=conv if self.store_latent else None)  # Don't write extra data.
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
            device=self.device)
        _pi, value, _ = self.model(*model_inputs)  # Ignore conv out
        return value.to("cpu")

    def set_act_uniform(self, act_uniform=True):
        self._act_uniform = act_uniform
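
The ``load_conv`` branch above keeps only the ``conv.*`` entries of an unsupervised-learning snapshot and strips the prefix so the keys line up with ``model.conv``. A small self-contained sketch of that key surgery, with an in-memory state dict standing in for the ``torch.load`` result:

from collections import OrderedDict

import torch.nn as nn

conv = nn.Sequential(nn.Conv2d(3, 32, 3), nn.ReLU(), nn.Conv2d(32, 32, 3))
# Pretend this came from torch.load(...)["algo_state_dict"]["encoder"]:
loaded_state_dict = OrderedDict(
    ("conv." + k, v) for k, v in conv.state_dict().items())

# Keep only the "conv." entries and drop the prefix so keys match conv's own:
conv_state_dict = OrderedDict(
    (k.replace("conv.", "", 1), v)
    for k, v in loaded_state_dict.items() if k.startswith("conv."))
conv.load_state_dict(conv_state_dict)
print(list(conv_state_dict)[:2])  # ['0.weight', '0.bias']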
Example no. 24
0
    def initialize(self, env_spaces, share_memory=False):
        super().initialize(env_spaces, share_memory)
        self.distribution = Categorical(dim=env_spaces.action.n)
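
For reference, the ``Categorical(dim=env_spaces.action.n)`` set up here is what later turns model probabilities into integer actions and one-hot previous actions. A conceptual stand-in using plain ``torch.distributions`` (not rlpyt's ``Categorical`` class):

import torch
import torch.nn.functional as F

n_actions = 6  # stands in for env_spaces.action.n
pi = torch.softmax(torch.randn(8, n_actions), dim=-1)  # [B, n] action probabilities
action = torch.distributions.Categorical(probs=pi).sample()  # [B] integer actions
onehot = F.one_hot(action, n_actions).float()  # analogous to distribution.to_onehot(...)
print(action.shape, onehot.shape)  # torch.Size([8]) torch.Size([8, 6])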
Example no. 25
0
class RecurrentCategoricalOCAgentBase(OCOptimizerMixin, BaseAgent):
    def __call__(self,
                 observation,
                 prev_action,
                 prev_reward,
                 sampled_option,
                 prev_option,
                 init_rnn_state,
                 device="cpu"):
        prev_action = self.distribution.to_onehot(prev_action)
        prev_option = self.distribution_omega.to_onehot_with_invalid(
            prev_option)
        model_inputs = buffer_to((observation, prev_action, prev_reward,
                                  prev_option, init_rnn_state, sampled_option),
                                 device=self.device)
        pi, beta, q, pi_omega, next_rnn_state = self.model(*model_inputs[:-1])
        return buffer_to(
            (DistInfo(prob=select_at_indexes(sampled_option, pi)), q, beta,
             DistInfo(prob=pi_omega)),
            device=device), next_rnn_state

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)
        self.distribution_omega = Categorical(
            dim=self.model_kwargs["option_size"])

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward, device="cpu"):
        prev_option_input = self._prev_option
        if prev_option_input is None:  # No previous option yet; use -1 as invalid.
            prev_option_input = torch.full_like(prev_action, -1)
        prev_action = self.distribution.to_onehot(prev_action)
        prev_option_input = self.distribution_omega.to_onehot_with_invalid(
            prev_option_input)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, prev_option_input),
            device=self.device)
        pi, beta, q, pi_omega, rnn_state = self.model(*model_inputs,
                                                      self.prev_rnn_state)
        dist_info_omega = DistInfo(prob=pi_omega)
        new_o, terminations = self.sample_option(
            beta, dist_info_omega)  # Sample terminations and options
        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
        dist_info = DistInfo(prob=pi)
        dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
        action = self.distribution.sample(dist_info_o)
        agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                    dist_info_o=dist_info_o,
                                    q=q,
                                    value=(pi_omega * q).sum(-1),
                                    termination=terminations,
                                    dist_info_omega=dist_info_omega,
                                    prev_o=self._prev_option,
                                    o=new_o,
                                    prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device=device)
        self.advance_oc_state(new_o)
        self.advance_rnn_state(rnn_state)
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward, device="cpu"):
        prev_option_input = self._prev_option
        if prev_option_input is None:  # No previous option yet; use -1 as invalid.
            prev_option_input = torch.full_like(prev_action, -1)
        prev_action = self.distribution.to_onehot(prev_action)
        prev_option_input = self.distribution_omega.to_onehot_with_invalid(
            prev_option_input)
        agent_inputs = buffer_to(
            (observation, prev_action, prev_reward, prev_option_input),
            device=self.device)
        _pi, beta, q, pi_omega, _rnn_state = self.model(
            *agent_inputs, self.prev_rnn_state)
        # Weight each option's q-value by its probability: the expected value
        # to bootstrap with if the previous option terminates.
        v = (q * pi_omega).sum(-1)
        q_prev_o = select_at_indexes(self.prev_option, q)
        beta_prev_o = select_at_indexes(self.prev_option, beta)
        value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o
        return value.to(device)
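
The ``value()`` method above bootstraps with Q(s, prev_o) when the previous option continues and with the option-averaged value V(s) = sum_o pi_omega(o|s) Q(s, o) when it terminates (probability beta). A dummy-tensor sketch of that computation, using ``gather`` in place of ``select_at_indexes``:

import torch

B, n_options = 4, 8
q = torch.randn(B, n_options)                                # Q(s, o)
beta = torch.rand(B, n_options)                              # termination probabilities
pi_omega = torch.softmax(torch.randn(B, n_options), dim=-1)  # option policy
prev_option = torch.randint(0, n_options, (B,))

v = (q * pi_omega).sum(-1)                                   # option-averaged value V(s)
q_prev_o = q.gather(-1, prev_option.unsqueeze(-1)).squeeze(-1)
beta_prev_o = beta.gather(-1, prev_option.unsqueeze(-1)).squeeze(-1)
value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o
print(value.shape)  # torch.Size([4])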
Example no. 26
0
class CategoricalOCAgentBase(OCOptimizerMixin, BaseAgent):
    def __call__(self,
                 observation,
                 prev_action,
                 prev_reward,
                 sampled_option,
                 device="cpu"):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, sampled_option),
            device=self.device)
        pi, beta, q, pi_omega = self.model(*model_inputs[:-1])
        return buffer_to(
            (DistInfo(prob=select_at_indexes(sampled_option, pi)), q, beta,
             DistInfo(prob=pi_omega)),
            device=device)

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)
        self.distribution_omega = Categorical(
            dim=self.model_kwargs["option_size"])

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward, device="cpu"):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, beta, q, pi_omega = self.model(*model_inputs)
        dist_info_omega = DistInfo(prob=pi_omega)
        new_o, terminations = self.sample_option(
            beta, dist_info_omega)  # Sample terminations and options
        dist_info = DistInfo(prob=pi)
        dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
        action = self.distribution.sample(dist_info_o)
        agent_info = AgentInfoOC(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi_omega * q).sum(-1),
                                 termination=terminations,
                                 dist_info_omega=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o)
        action, agent_info = buffer_to((action, agent_info), device=device)
        self.advance_oc_state(new_o)
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward, device="cpu"):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, beta, q, pi_omega = self.model(*model_inputs)
        # Weight each option's q-value by its probability: the expected value
        # to bootstrap with if the previous option terminates.
        v = (q * pi_omega).sum(-1)
        q_prev_o = select_at_indexes(self.prev_option, q)
        beta_prev_o = select_at_indexes(self.prev_option, beta)
        value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o
        return value.to(device)
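
``step()`` above relies on ``sample_option(beta, dist_info_omega)`` to decide whether the current option terminates and, if so, which option replaces it. A plausible standalone sketch of that mechanism under standard option-critic conventions (not the repository's exact implementation):

import torch

B, n_options = 4, 8
beta = torch.rand(B, n_options)                              # termination probabilities
pi_omega = torch.softmax(torch.randn(B, n_options), dim=-1)  # option policy
prev_option = torch.randint(0, n_options, (B,))

beta_prev = beta.gather(-1, prev_option.unsqueeze(-1)).squeeze(-1)
terminations = torch.bernoulli(beta_prev).bool()             # terminate prev option?
resampled = torch.distributions.Categorical(probs=pi_omega).sample()
new_option = torch.where(terminations, resampled, prev_option)
print(new_option.shape, terminations.shape)  # torch.Size([4]) torch.Size([4])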
Example no. 27
0
class CategoricalPgAgent(BaseAgent):
    """
    Agent for policy gradient algorithm using categorical action distribution.
    Same as ``GaussianPgAgent`` and related classes, except uses
    ``Categorical`` distribution, and has a different interface to the model
    (model here outputs discrete probabilities in place of means and log_stds,
    while both output the value estimate).
    """
    def __call__(self, observation, prev_action, prev_reward, dual=False):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value = (self.model_int if dual else self.model)(*model_inputs)
        return buffer_to((DistInfo(prob=pi), value), device="cpu")

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   env_ranks=None,
                   **kwargs):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           env_ranks=env_ranks,
                           **kwargs)
        self.distribution = Categorical(dim=env_spaces.action.n)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            # Collect with the intrinsic (exploration) policy; sample from the
            # extrinsic policy only in eval mode.
            int_pi, int_value = self.model_int(*model_inputs)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
            agent_info = AgentInfoTwin(dist_info=dist_info,
                                       value=value,
                                       dist_int_info=dist_int_info,
                                       int_value=int_value)
        else:
            action = self.distribution.sample(dist_info)
            agent_info = AgentInfo(dist_info=dist_info, value=value)

        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward, ret_int=False):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        if ret_int:
            assert self.dual_model
            _pi, value = self.model_int(*model_inputs)
        else:
            _pi, value = self.model(*model_inputs)
        return value.to("cpu")
Example no. 28
0
class RecurrentCategoricalPgAgentBase(BaseAgent):
    def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
        # Assume init_rnn_state already shaped: [N,B,H]
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, init_rnn_state),
            device=self.device)
        pi, value, next_rnn_state = self.model(*model_inputs)
        dist_info, value = buffer_to((DistInfo(prob=pi), value), device="cpu")
        return dist_info, value, next_rnn_state  # Leave rnn_state on device.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   obs_stats=None,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           obs_stats=obs_stats,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            int_pi, int_value, int_rnn_state = self.model_int(
                *agent_inputs, self.prev_int_rnn_state)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
        else:
            action = self.distribution.sample(dist_info)

        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state or buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)

        if self.dual_model:
            prev_int_rnn_state = self.prev_int_rnn_state or buffer_func(
                int_rnn_state, torch.zeros_like)
            prev_int_rnn_state = buffer_method(prev_int_rnn_state, "transpose",
                                               0, 1)
            agent_info = AgentInfoRnnTwin(
                dist_info=dist_info,
                value=value,
                prev_rnn_state=prev_rnn_state,
                dist_int_info=dist_int_info,
                int_value=int_value,
                prev_int_rnn_state=prev_int_rnn_state)
        else:
            agent_info = AgentInfoRnn(dist_info=dist_info,
                                      value=value,
                                      prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        if self.dual_model:
            self.advance_int_rnn_state(int_rnn_state)
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def curiosity_step(self, curiosity_type, *args):
        curiosity_model = self.model.module.curiosity_model if isinstance(
            self.model, torch.nn.parallel.DistributedDataParallel
        ) else self.model.curiosity_model
        curiosity_step_minibatches = self.model_kwargs[
            'curiosity_step_kwargs']['curiosity_step_minibatches']
        T, B = args[0].shape[:2]  # either observation or next_observation
        batch_size = B
        mb_size = batch_size // curiosity_step_minibatches

        if curiosity_type in {'icm', 'micm', 'disagreement'}:
            observation, next_observation, actions = args
            actions = self.distribution.to_onehot(actions)
            curiosity_agent_inputs = IcmAgentCuriosityStepInputs(
                observation=observation,
                next_observation=next_observation,
                actions=actions)
            curiosity_agent_inputs = buffer_to(curiosity_agent_inputs,
                                               device=self.device)
            agent_curiosity_info = IcmInfo()
        elif curiosity_type == 'ndigo':
            observation, prev_actions, actions = args
            actions = self.distribution.to_onehot(actions)
            prev_actions = self.distribution.to_onehot(prev_actions)
            curiosity_agent_inputs = NdigoAgentCuriosityStepInputs(
                observations=observation,
                prev_actions=prev_actions,
                actions=actions)
            curiosity_agent_inputs = buffer_to(curiosity_agent_inputs,
                                               device=self.device)
            agent_curiosity_info = NdigoInfo(prev_gru_state=None)
        elif curiosity_type == 'rnd':
            next_observation, done = args
            curiosity_agent_inputs = RndAgentCuriosityStepInputs(
                next_observation=next_observation, done=done)
            curiosity_agent_inputs = buffer_to(curiosity_agent_inputs,
                                               device=self.device)
            agent_curiosity_info = RndInfo()

        # Split the intrinsic reward computation into minibatches along B;
        # otherwise it can run out of GPU memory.
        r_ints = []
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=False):
            mb_r_int = curiosity_model.compute_bonus(
                *curiosity_agent_inputs[slice(None), idxs])
            r_ints.append(mb_r_int)
        r_int = torch.cat(r_ints, dim=1)

        r_int, agent_curiosity_info = buffer_to((r_int, agent_curiosity_info),
                                                device="cpu")

        return AgentCuriosityStep(r_int=r_int,
                                  agent_curiosity_info=agent_curiosity_info)

    def curiosity_loss(self, curiosity_type, *args):

        curiosity_model = self.model.module.curiosity_model if isinstance(
            self.model, torch.nn.parallel.DistributedDataParallel
        ) else self.model.curiosity_model
        if curiosity_type in {'icm', 'micm'}:
            observation, next_observation, actions, valid = args
            actions = self.distribution.to_onehot(actions)
            actions = actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            curiosity_agent_inputs = buffer_to(
                (observation, next_observation, actions, valid),
                device=self.device)
            inv_loss, forward_loss = curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            # inv_loss, forward_loss = curiosity_model.compute_loss(*args)
            losses = (inv_loss.to("cpu"), forward_loss.to("cpu"))
        elif curiosity_type == 'disagreement':
            observation, next_observation, actions, valid = args
            actions = self.distribution.to_onehot(actions)
            actions = actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            curiosity_agent_inputs = buffer_to(
                (observation, next_observation, actions, valid),
                device=self.device)
            forward_loss = curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            losses = (forward_loss.to("cpu"))
        elif curiosity_type == 'ndigo':
            observations, prev_actions, actions, valid = args
            actions = self.distribution.to_onehot(actions)
            prev_actions = self.distribution.to_onehot(prev_actions)
            actions = actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            prev_actions = prev_actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            curiosity_agent_inputs = buffer_to(
                (observations, prev_actions, actions, valid),
                device=self.device)
            forward_loss = curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            losses = (forward_loss.to("cpu"))
        elif curiosity_type == 'rnd':
            next_observation, valid = args
            curiosity_agent_inputs = buffer_to((next_observation, valid),
                                               device=self.device)
            forward_loss = curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            losses = forward_loss.to("cpu")
        else:
            raise ValueError(f"Unrecognized curiosity_type: {curiosity_type}")

        return losses

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward, ret_int=False):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # _pi, value, _rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        if ret_int:
            assert self.dual_model
            _pi, value, _rnn_state = self.model_int(*agent_inputs,
                                                    self.prev_int_rnn_state)
        else:
            _pi, value, _rnn_state = self.model(*agent_inputs,
                                                self.prev_rnn_state)
        return value.to("cpu")
Example no. 29
0
class RecurrentCategoricalPgAgentBase(BaseAgent):
    def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
        # Assume init_rnn_state already shaped: [N,B,H]
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to(
            (observation, prev_action, prev_reward, init_rnn_state),
            device=self.device)
        pi, value, next_rnn_state = self.model(*model_inputs)
        dist_info, value = buffer_to((DistInfo(prob=pi), value), device="cpu")
        return dist_info, value, next_rnn_state  # Leave rnn_state on device.

    def initialize(self,
                   env_spaces,
                   share_memory=False,
                   global_B=1,
                   obs_stats=None,
                   env_ranks=None):
        super().initialize(env_spaces,
                           share_memory,
                           global_B=global_B,
                           obs_stats=obs_stats,
                           env_ranks=env_ranks)
        self.distribution = Categorical(dim=env_spaces.action.n)

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state or buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
        agent_info = AgentInfoRnn(dist_info=dist_info,
                                  value=value,
                                  prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        return AgentStep(action=action, agent_info=agent_info)

    @torch.no_grad()
    def curiosity_step(self, curiosity_type, *args):

        if curiosity_type == 'icm' or curiosity_type == 'disagreement':
            observation, next_observation, actions = args
            actions = self.distribution.to_onehot(actions)
            curiosity_agent_inputs = buffer_to(
                (observation, next_observation, actions), device=self.device)
            agent_curiosity_info = IcmInfo()
        elif curiosity_type == 'ndigo':
            observation, prev_actions, actions = args
            actions = self.distribution.to_onehot(actions)
            prev_actions = self.distribution.to_onehot(prev_actions)
            curiosity_agent_inputs = buffer_to(
                (observation, prev_actions, actions), device=self.device)
            agent_curiosity_info = NdigoInfo(prev_gru_state=None)
        elif curiosity_type == 'rnd':
            next_observation, done = args
            curiosity_agent_inputs = buffer_to((next_observation, done),
                                               device=self.device)
            agent_curiosity_info = RndInfo()

        r_int = self.model.curiosity_model.compute_bonus(
            *curiosity_agent_inputs)
        r_int, agent_curiosity_info = buffer_to((r_int, agent_curiosity_info),
                                                device="cpu")
        return AgentCuriosityStep(r_int=r_int,
                                  agent_curiosity_info=agent_curiosity_info)

    def curiosity_loss(self, curiosity_type, *args):

        if curiosity_type == 'icm' or curiosity_type == 'disagreement':
            observation, next_observation, actions, valid = args
            actions = self.distribution.to_onehot(actions)
            actions = actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            curiosity_agent_inputs = buffer_to(
                (observation, next_observation, actions, valid),
                device=self.device)
            inv_loss, forward_loss = self.model.curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            losses = (inv_loss.to("cpu"), forward_loss.to("cpu"))
        elif curiosity_type == 'ndigo':
            observations, prev_actions, actions, valid = args
            actions = self.distribution.to_onehot(actions)
            prev_actions = self.distribution.to_onehot(prev_actions)
            actions = actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            prev_actions = prev_actions.squeeze(
            )  # ([batch, 1, size]) -> ([batch, size])
            curiosity_agent_inputs = buffer_to(
                (observations, prev_actions, actions, valid),
                device=self.device)
            forward_loss = self.model.curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            losses = (forward_loss.to("cpu"))
        elif curiosity_type == 'rnd':
            next_observation, valid = args
            curiosity_agent_inputs = buffer_to((next_observation, valid),
                                               device=self.device)
            forward_loss = self.model.curiosity_model.compute_loss(
                *curiosity_agent_inputs)
            losses = forward_loss.to("cpu")
        else:
            raise ValueError(f"Unrecognized curiosity_type: {curiosity_type}")

        return losses

    @torch.no_grad()
    def value(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _pi, value, _rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        return value.to("cpu")
Example no. 30
0
class VAE(BaseUlAlgorithm):
    """VAE to predict o_t+k from o_t."""

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
            self,
            batch_size,
            learning_rate,
            replay_filepath,
            delta_T=0,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_state_dict=None,
            clip_grad_norm=1000.,
            EncoderCls=EncoderModel,
            encoder_kwargs=None,
            latent_size=128,
            ReplayCls=UlForRlReplayBuffer,
            activation_loss_coefficient=0.0,
            learning_rate_anneal=None,  # None or "cosine"
            learning_rate_warmup=0,  # number of updates
            VaeHeadCls=VaeHeadModel,
            hidden_sizes=None,  # VAE head hidden sizes; useful for forward prediction.
            DecoderCls=VaeDecoderModel,
            decoder_kwargs=None,
            kl_coeff=1.,
            onehot_action=True,
            validation_split=0.0,
            n_validation_batches=0,
            ):
        optim_kwargs = dict() if optim_kwargs is None else optim_kwargs
        encoder_kwargs = dict() if encoder_kwargs is None else encoder_kwargs
        decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs
        save__init__args(locals())
        self.c_e_loss = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
        assert learning_rate_anneal in [None, "cosine"]
        self._replay_T = delta_T + 1

    def initialize(self, n_updates, cuda_idx=None):
        self.device = torch.device("cpu") if cuda_idx is None else torch.device(
            "cuda", index=cuda_idx)

        examples = self.load_replay()
        self.encoder = self.EncoderCls(
            image_shape=examples.observation.shape,
            latent_size=self.latent_size,  # UNUSED
            **self.encoder_kwargs
        )
        if self.onehot_action:
            act_dim = self.replay_buffer.samples.action.max() + 1  # discrete only
            self.distribution = Categorical(act_dim)
        else:
            act_shape = self.replay_buffer.samples.action.shape[2:]
            assert len(act_shape) == 1
            act_dim = act_shape[0]
        self.vae_head = self.VaeHeadCls(
            latent_size=self.latent_size,
            action_size=act_dim * self.delta_T,
            hidden_sizes=self.hidden_sizes,
        )
        self.decoder = self.DecoderCls(
            latent_size=self.latent_size,
            **self.decoder_kwargs
        )
        self.encoder.to(self.device)
        self.vae_head.to(self.device)
        self.decoder.to(self.device)

        self.optim_initialize(n_updates)

        if self.initial_state_dict is not None:
            self.load_state_dict(self.initial_state_dict)

    def optimize(self, itr):
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        samples = self.replay_buffer.sample_batch(self.batch_size)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(itr)  # Do every itr instead of every epoch
        self.optimizer.zero_grad()
        recon_loss, kl_loss, conv_output = self.vae_loss(samples)
        act_loss = self.activation_loss(conv_output)
        loss = recon_loss + kl_loss + act_loss
        loss.backward()
        if self.clip_grad_norm is None:
            grad_norm = torch.tensor(0.)  # Tensor, so .item() below still works.
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.parameters(), self.clip_grad_norm)
        self.optimizer.step()
        opt_info.reconLoss.append(recon_loss.item())
        opt_info.klLoss.append(kl_loss.item())
        opt_info.activationLoss.append(act_loss.item())
        opt_info.gradNorm.append(grad_norm.item())
        opt_info.convActivation.append(
            conv_output[0].detach().cpu().view(-1).numpy())  # Keep 1 full one.
        return opt_info

    def vae_loss(self, samples):
        observation = samples.observation[0]  # [T,B,C,H,W]->[B,C,H,W]
        target_observation = samples.observation[self.delta_T]
        if self.delta_T > 0:
            action = samples.action[:-1]  # [T-1,B,A]  don't need the last one
            if self.onehot_action:
                action = self.distribution.to_onehot(action)
            t, b = action.shape[:2]
            action = action.transpose(1, 0)  # [B,T-1,A]
            action = action.reshape(b, -1)
        else:
            action = None
        observation, target_observation, action = buffer_to(
            (observation, target_observation, action),
            device=self.device
        )

        h, conv_out = self.encoder(observation)
        z, mu, logvar = self.vae_head(h, action)
        recon_z = self.decoder(z)

        if target_observation.dtype == torch.uint8:
            target_observation = target_observation.type(torch.float)
            target_observation = target_observation.mul_(1 / 255.)

        b, c, h, w = target_observation.shape
        recon_losses = F.binary_cross_entropy(
            input=recon_z.reshape(b * c, h, w),
            target=target_observation.reshape(b * c, h, w),
            reduction="none",
        )
        if self.delta_T > 0:
            valid = valid_from_done(samples.done).type(torch.bool)  # [T,B]
            valid = valid[-1]  # [B]
            valid = valid.to(self.device)
        else:
            valid = None  # all are valid
        recon_losses = recon_losses.view(b, c, h, w).sum(dim=(2, 3))  # sum over H,W
        recon_losses = recon_losses.mean(dim=1)  # mean over C (o/w loss is HUGE)
        recon_loss = valid_mean(recon_losses, valid=valid)  # mean over batch

        kl_losses = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_losses = kl_losses.sum(dim=-1)  # sum over latent dimension
        kl_loss = -0.5 * valid_mean(kl_losses, valid=valid)  # mean over batch
        kl_loss = self.kl_coeff * kl_loss

        return recon_loss, kl_loss, conv_out

    def validation(self, itr):
        logger.log("Computing validation loss...")
        val_info = ValInfo(*([] for _ in range(len(ValInfo._fields))))
        self.optimizer.zero_grad()
        for _ in range(self.n_validation_batches):
            samples = self.replay_buffer.sample_batch(self.batch_size,
                validation=True)
            with torch.no_grad():
                recon_loss, kl_loss, conv_output = self.vae_loss(samples)
            val_info.reconLoss.append(recon_loss.item())
            val_info.klLoss.append(kl_loss.item())
            val_info.convActivation.append(
                conv_output[0].detach().cpu().view(-1).numpy())  # Keep 1 full one.
        self.optimizer.zero_grad()
        logger.log("...validation loss completed.")
        return val_info

    def state_dict(self):
        return dict(
            encoder=self.encoder.state_dict(),
            vae_head=self.vae_head.state_dict(),
            decoder=self.decoder.state_dict(),
            optimizer=self.optimizer.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.encoder.load_state_dict(state_dict["encoder"])
        self.vae_head.load_state_dict(state_dict["vae_head"])
        self.decoder.load_state_dict(state_dict["decoder"])
        self.optimizer.load_state_dict(state_dict["optimizer"])

    def parameters(self):
        yield from self.encoder.parameters()
        yield from self.vae_head.parameters()
        yield from self.decoder.parameters()

    def named_parameters(self):
        """To allow filtering by name in weight decay."""
        yield from self.encoder.named_parameters()
        yield from self.vae_head.named_parameters()
        yield from self.decoder.named_parameters()

    def eval(self):
        self.encoder.eval()  # in case of batch norm
        self.vae_head.eval()
        self.decoder.eval()

    def train(self):
        self.encoder.train()
        self.vae_head.train()
        self.decoder.train()
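
For reference, the two loss terms computed in ``vae_loss`` above can be reproduced with toy tensors: per-pixel binary cross-entropy summed over H and W and averaged over channels, plus the analytic Gaussian KL term -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2). A standalone sketch (ignoring the ``valid`` masking):

import torch
import torch.nn.functional as F

b, c, h, w, latent = 8, 3, 84, 84, 128
recon = torch.sigmoid(torch.randn(b, c, h, w))   # decoder output in (0, 1)
target = torch.rand(b, c, h, w)                  # target frame scaled to [0, 1]
mu, logvar = torch.randn(b, latent), torch.randn(b, latent)

recon_losses = F.binary_cross_entropy(recon, target, reduction="none")
recon_loss = recon_losses.sum(dim=(2, 3)).mean(dim=1).mean()  # sum H,W; mean C; mean batch

kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)  # per-sample KL
kl_loss = kl.mean()                                              # mean over batch
loss = recon_loss + 1.0 * kl_loss                                # kl_coeff = 1.0
print(loss.item())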