class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes,
                                    state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"],
                                          dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(mini_batch["visual_obs%d" % i],
                                          dtype=torch.float)
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(
            torch.tensor(mini_batch["vector_obs"]))
        return hidden
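The class above only produces a state embedding. The RND intrinsic reward comes from pairing two such encoders: a frozen, randomly initialized target and a trained predictor, and rewarding states where their outputs disagree. Below is a minimal sketch in plain PyTorch; the function and argument names are illustrative, not part of the ML-Agents API.

# Minimal sketch (plain PyTorch, illustrative names): the RND bonus is the
# per-sample squared error between a frozen random "target" encoder and a
# trained "predictor" encoder. Novel states are predicted poorly, so their
# error, and therefore their exploration reward, is large.
import torch


def rnd_bonus(predictor: torch.nn.Module,
              target: torch.nn.Module,
              obs: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        target_embedding = target(obs)      # frozen, never trained
    predicted_embedding = predictor(obs)    # trained to match the target
    return torch.sum((predicted_embedding - target_embedding) ** 2, dim=1)

The same quantity, averaged over the mini-batch, doubles as the predictor's training loss.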
Example #2
class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_obs = len(self._encoder.processors)
        np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

        hidden, _ = self._encoder.forward(tensor_obs)
        self._encoder.update_normalization(mini_batch)
        return hidden
Example #3
class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        state_encoder_settings = settings.network_settings
        if state_encoder_settings.memory is not None:
            state_encoder_settings.memory = None
            logger.warning(
                "memory was specified in network_settings but is not supported by RND. It is being ignored."
            )

        self._encoder = NetworkBody(specs.observation_specs,
                                    state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_obs = len(self._encoder.processors)
        np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

        hidden, _ = self._encoder.forward(tensor_obs)
        self._encoder.update_normalization(mini_batch)
        return hidden
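To turn that embedding into a reward, a reward module typically holds two RNDNetwork instances, one frozen and one trained. The sketch below shows that pairing under stated assumptions; the class and method names are illustrative, not the ML-Agents reward-provider API.

class RNDModuleSketch(torch.nn.Module):
    """Illustrative pairing of a trained RNDNetwork with a frozen random copy."""

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        self._training_net = RNDNetwork(specs, settings)
        self._frozen_net = RNDNetwork(specs, settings)
        for param in self._frozen_net.parameters():
            param.requires_grad = False  # the random target is never updated

    def reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        # Per-sample prediction error of the trained net against the frozen target.
        with torch.no_grad():
            target = self._frozen_net(mini_batch)
        prediction = self._training_net(mini_batch)
        return torch.sum((prediction - target) ** 2, dim=1)

    def loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        # The mean prediction error is also the training objective.
        return torch.mean(self.reward(mini_batch))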
class CuriosityNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec,
                 settings: CuriositySettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(specs.observation_shapes,
                                          state_encoder_settings)

        self._action_flattener = ModelUtils.ActionFlattener(specs)

        self.inverse_model_action_prediction = torch.nn.Sequential(
            LinearEncoder(2 * settings.encoding_size, 1, 256),
            linear_layer(256, self._action_flattener.flattened_size),
        )

        self.forward_model_next_state_prediction = torch.nn.Sequential(
            LinearEncoder(
                settings.encoding_size + self._action_flattener.flattened_size,
                1, 256),
            linear_layer(256, settings.encoding_size),
        )

    def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the current state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_processors)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"],
                                          dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(mini_batch["visual_obs%d" % i],
                                          dtype=torch.float)
                for i in range(n_vis)
            ],
        )
        return hidden

    def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the next state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_processors)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["next_vector_in"],
                                          dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(mini_batch["next_visual_obs%d" % i],
                                          dtype=torch.float)
                for i in range(n_vis)
            ],
        )
        return hidden

    def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the per-branch action probabilities
        (softmax of the predicted logits).
        """
        inverse_model_input = torch.cat((self.get_current_state(mini_batch),
                                         self.get_next_state(mini_batch)),
                                        dim=1)
        hidden = self.inverse_model_action_prediction(inverse_model_input)
        if self._policy_specs.is_action_continuous():
            return hidden
        else:
            branches = ModelUtils.break_into_branches(
                hidden, self._policy_specs.discrete_action_branches)
            branches = [torch.softmax(b, dim=1) for b in branches]
            return torch.cat(branches, dim=1)

    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to predict
        the next state embedding.
        """
        if self._policy_specs.is_action_continuous():
            action = ModelUtils.list_to_tensor(mini_batch["actions"],
                                               dtype=torch.float)
        else:
            action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"],
                                              dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), action), dim=1)

        return self.forward_model_next_state_prediction(forward_model_input)

    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on the
        action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        if self._policy_specs.is_action_continuous():
            sq_difference = (ModelUtils.list_to_tensor(mini_batch["actions"],
                                                       dtype=torch.float) -
                             predicted_action)**2
            sq_difference = torch.sum(sq_difference, dim=1)
            return torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(mini_batch["masks"],
                                              dtype=torch.float),
                    2,
                )[1])
        else:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"],
                                              dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action + self.EPSILON) * true_action,
                dim=1)
            return torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(
                        mini_batch["masks"],
                        dtype=torch.float),  # use masks not action_masks
                    2,
                )[1])

    def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Calculates the curiosity reward for the mini_batch. Corresponds to the error
        between the predicted and actual next state.
        """
        predicted_next_state = self.predict_next_state(mini_batch)
        target = self.get_next_state(mini_batch)
        sq_difference = 0.5 * (target - predicted_next_state)**2
        sq_difference = torch.sum(sq_difference, dim=1)
        return sq_difference

    def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the loss for the next state prediction
        """
        return torch.mean(
            ModelUtils.dynamic_partition(
                self.compute_reward(mini_batch),
                ModelUtils.list_to_tensor(mini_batch["masks"],
                                          dtype=torch.float),
                2,
            )[1])
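The forward and inverse losses above are normally mixed into a single training objective, and the scaled forward error is the intrinsic reward, as in the original ICM formulation. Here is a hedged sketch of that combination; `forward_loss_weight` and `reward_strength` are illustrative hyperparameter names and default values, not taken from ML-Agents.

# Hedged sketch: ICM-style combination of the curiosity losses. A convex mix
# of the inverse and forward losses trains the module, while the forward
# prediction error (scaled by a strength coefficient) becomes the reward.
def curiosity_update(module: CuriosityNetwork,
                     mini_batch: AgentBuffer,
                     forward_loss_weight: float = 0.2,
                     reward_strength: float = 0.02):
    inverse_loss = module.compute_inverse_loss(mini_batch)
    forward_loss = module.compute_forward_loss(mini_batch)
    loss = (1.0 - forward_loss_weight) * inverse_loss \
        + forward_loss_weight * forward_loss
    with torch.no_grad():
        reward = reward_strength * module.compute_reward(mini_batch)
    return loss, reward  # loss to minimize, per-sample intrinsic reward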
Example #5
class DiscriminatorNetwork(torch.nn.Module):
    gradient_penalty_weight = 10.0  # coefficient on the gradient-penalty term
    z_size = 128  # size of the VAIL latent code
    alpha = 0.0005  # step size for the VAIL beta update
    mutual_information = 0.5  # VAIL mutual-information constraint target
    EPSILON = 1e-7  # numerical stability constant for logs
    initial_beta = 0.0  # initial value of the VAIL Lagrange multiplier beta

    def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        self._use_vail = settings.use_vail
        self._settings = settings

        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(specs.observation_shapes,
                                          state_encoder_settings)

        self._action_flattener = ModelUtils.ActionFlattener(specs)

        encoder_input_size = settings.encoding_size
        if settings.use_actions:
            # + 1 is for done
            encoder_input_size += self._action_flattener.flattened_size + 1

        self.encoder = torch.nn.Sequential(
            linear_layer(encoder_input_size, settings.encoding_size),
            Swish(),
            linear_layer(settings.encoding_size, settings.encoding_size),
            Swish(),
        )

        estimator_input_size = settings.encoding_size
        if settings.use_vail:
            estimator_input_size = self.z_size
            self._z_sigma = torch.nn.Parameter(torch.ones((self.z_size),
                                                          dtype=torch.float),
                                               requires_grad=True)
            self._z_mu_layer = linear_layer(
                settings.encoding_size,
                self.z_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
            )
            self._beta = torch.nn.Parameter(torch.tensor(self.initial_beta,
                                                         dtype=torch.float),
                                            requires_grad=False)

        self._estimator = torch.nn.Sequential(
            linear_layer(estimator_input_size, 1), torch.nn.Sigmoid())

    def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the action Tensor. In the continuous case, this is the action itself;
        in the discrete case, it is the concatenation of one-hot action Tensors.
        """
        return self._action_flattener.forward(
            torch.as_tensor(mini_batch["actions"], dtype=torch.float))

    def get_state_encoding(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the observation input.
        """
        n_vis = len(self._state_encoder.visual_encoders)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                torch.as_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                torch.as_tensor(mini_batch["visual_obs%d" % i],
                                dtype=torch.float) for i in range(n_vis)
            ],
        )
        return hidden

    def compute_estimate(
            self,
            mini_batch: AgentBuffer,
            use_vail_noise: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Given a mini_batch, computes the estimate (how strongly the discriminator
        believes the data was sampled from the demonstrations) and, when VAIL is
        used, the mean of the latent code.
        :param mini_batch: The AgentBuffer of data
        :param use_vail_noise: Only used with VAIL. If true, samples the latent
        code; if false, uses its mean.
        """
        encoder_input = self.get_state_encoding(mini_batch)
        if self._settings.use_actions:
            actions = self.get_action_input(mini_batch)
            dones = torch.as_tensor(mini_batch["done"], dtype=torch.float)
            encoder_input = torch.cat([encoder_input, actions, dones], dim=1)
        hidden = self.encoder(encoder_input)
        z_mu: Optional[torch.Tensor] = None
        if self._settings.use_vail:
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        estimate = self._estimator(hidden)
        return estimate, z_mu

    def compute_loss(
            self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
        """
        total_loss = torch.zeros(1)
        stats_dict: Dict[str, np.ndarray] = {}
        policy_estimate, policy_mu = self.compute_estimate(policy_batch,
                                                           use_vail_noise=True)
        expert_estimate, expert_mu = self.compute_estimate(expert_batch,
                                                           use_vail_noise=True)
        stats_dict["Policy/GAIL Policy Estimate"] = (
            policy_estimate.mean().detach().cpu().numpy())
        stats_dict["Policy/GAIL Expert Estimate"] = (
            expert_estimate.mean().detach().cpu().numpy())
        discriminator_loss = -(
            torch.log(expert_estimate + self.EPSILON) +
            torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
        stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu(
        ).numpy()
        total_loss += discriminator_loss
        if self._settings.use_vail:
            # KL divergence loss (encourage latent representation to be normal)
            kl_loss = torch.mean(-torch.sum(
                1 + (self._z_sigma**2).log() - 0.5 * expert_mu**2 -
                0.5 * policy_mu**2 - (self._z_sigma**2),
                dim=1,
            ))
            vail_loss = self._beta * (kl_loss - self.mutual_information)
            with torch.no_grad():
                self._beta.data = torch.max(
                    self._beta + self.alpha *
                    (kl_loss - self.mutual_information),
                    torch.tensor(0.0),
                )
            total_loss += vail_loss
            stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy()
            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
        if self.gradient_penalty_weight > 0.0:
            total_loss += (
                self.gradient_penalty_weight *
                self.compute_gradient_magnitude(policy_batch, expert_batch))
        return total_loss, stats_dict

    def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                                   expert_batch: AgentBuffer) -> torch.Tensor:
        """
        Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
        especially for off-policy training. Computes gradients w.r.t. a randomly
        interpolated input between policy and expert samples.
        """
        policy_obs = self.get_state_encoding(policy_batch)
        expert_obs = self.get_state_encoding(expert_batch)
        obs_epsilon = torch.rand(policy_obs.shape)
        encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
            action_epsilon = torch.rand(policy_action.shape)
            policy_dones = torch.as_tensor(policy_batch["done"],
                                           dtype=torch.float)
            expert_dones = torch.as_tensor(expert_batch["done"],
                                           dtype=torch.float)
            dones_epsilon = torch.rand(policy_dones.shape)
            encoder_input = torch.cat(
                [
                    encoder_input,
                    action_epsilon * policy_action +
                    (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones +
                    (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
        hidden = self.encoder(encoder_input)
        if self._settings.use_vail:
            use_vail_noise = True
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        hidden = self._estimator(hidden)
        estimate = torch.mean(torch.sum(hidden, dim=1))
        # create_graph=True so the penalty term itself remains differentiable
        gradient = torch.autograd.grad(estimate, encoder_input,
                                       create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1)**2)
        return gradient_mag
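At rollout time, the discriminator's estimate is converted into a reward for the policy: batches that the discriminator mistakes for demonstrations score highly. Below is a hedged sketch of that conversion using the standard -log(1 - D) GAIL reward; the function name is illustrative, not the ML-Agents API.

def gail_reward(discriminator: DiscriminatorNetwork,
                mini_batch: AgentBuffer) -> torch.Tensor:
    # Estimates near 1 ("looks like the demonstrations") give a large reward,
    # estimates near 0 give a reward close to zero.
    with torch.no_grad():
        estimate, _ = discriminator.compute_estimate(mini_batch,
                                                     use_vail_noise=False)
    return -torch.log(1.0 - estimate.squeeze(dim=1) + discriminator.EPSILON)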
class CuriosityNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec,
                 settings: CuriositySettings) -> None:
        super().__init__()
        self._action_spec = specs.action_spec
        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(specs.observation_specs,
                                          state_encoder_settings)

        self._action_flattener = ActionFlattener(self._action_spec)

        self.inverse_model_action_encoding = torch.nn.Sequential(
            LinearEncoder(2 * settings.encoding_size, 1, 256))

        if self._action_spec.continuous_size > 0:
            self.continuous_action_prediction = linear_layer(
                256, self._action_spec.continuous_size)
        if self._action_spec.discrete_size > 0:
            self.discrete_action_prediction = linear_layer(
                256, sum(self._action_spec.discrete_branches))

        self.forward_model_next_state_prediction = torch.nn.Sequential(
            LinearEncoder(
                settings.encoding_size + self._action_flattener.flattened_size,
                1, 256),
            linear_layer(256, settings.encoding_size),
        )

    def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the current state embedding from a mini_batch.
        """
        n_obs = len(self._state_encoder.processors)
        np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

        hidden, _ = self._state_encoder.forward(tensor_obs)
        return hidden

    def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the next state embedding from a mini_batch.
        """
        n_obs = len(self._state_encoder.processors)
        np_obs = ObsUtil.from_buffer_next(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

        hidden, _ = self._state_encoder.forward(tensor_obs)
        return hidden

    def predict_action(self, mini_batch: AgentBuffer) -> ActionPredictionTuple:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the per-branch action probabilities
        (softmax of the predicted logits).
        """
        inverse_model_input = torch.cat((self.get_current_state(mini_batch),
                                         self.get_next_state(mini_batch)),
                                        dim=1)

        continuous_pred = None
        discrete_pred = None
        hidden = self.inverse_model_action_encoding(inverse_model_input)
        if self._action_spec.continuous_size > 0:
            continuous_pred = self.continuous_action_prediction(hidden)
        if self._action_spec.discrete_size > 0:
            raw_discrete_pred = self.discrete_action_prediction(hidden)
            branches = ModelUtils.break_into_branches(
                raw_discrete_pred, self._action_spec.discrete_branches)
            branches = [torch.softmax(b, dim=1) for b in branches]
            discrete_pred = torch.cat(branches, dim=1)
        return ActionPredictionTuple(continuous_pred, discrete_pred)

    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to predict
        the next state embedding.
        """
        actions = AgentAction.from_buffer(mini_batch)
        flattened_action = self._action_flattener.forward(actions)
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), flattened_action), dim=1)

        return self.forward_model_next_state_prediction(forward_model_input)

    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on the
        action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        actions = AgentAction.from_buffer(mini_batch)
        _inverse_loss = 0
        if self._action_spec.continuous_size > 0:
            sq_difference = (actions.continuous_tensor -
                             predicted_action.continuous)**2
            sq_difference = torch.sum(sq_difference, dim=1)
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(mini_batch[BufferKey.MASKS],
                                              dtype=torch.float),
                    2,
                )[1])
        if self._action_spec.discrete_size > 0:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    actions.discrete_tensor,
                    self._action_spec.discrete_branches),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action.discrete + self.EPSILON) *
                true_action,
                dim=1,
            )
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(
                        mini_batch[BufferKey.MASKS],
                        dtype=torch.float),  # use masks not action_masks
                    2,
                )[1])
        return _inverse_loss

    def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Calculates the curiosity reward for the mini_batch. Corresponds to the error
        between the predicted and actual next state.
        """
        predicted_next_state = self.predict_next_state(mini_batch)
        target = self.get_next_state(mini_batch)
        sq_difference = 0.5 * (target - predicted_next_state)**2
        sq_difference = torch.sum(sq_difference, dim=1)
        return sq_difference

    def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the loss for the next state prediction
        """
        return torch.mean(
            ModelUtils.dynamic_partition(
                self.compute_reward(mini_batch),
                ModelUtils.list_to_tensor(mini_batch[BufferKey.MASKS],
                                          dtype=torch.float),
                2,
            )[1])
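Both loss methods average only over the unmasked entries of the mini-batch via ModelUtils.dynamic_partition(...)[1]. The snippet below is a plain-PyTorch equivalent of that masked mean, shown as a hedged sketch with an illustrative helper name.

def masked_mean(values: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Equivalent of dynamic_partition(values, masks, 2)[1].mean(): keep only
    # the entries whose mask is 1 and average them.
    return values[masks.to(dtype=torch.bool)].mean()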