Example No. 1
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        if not discrete:
            with torch.no_grad():
                target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
            entropy_loss = -1 * ModelUtils.masked_mean(
                self._log_ent_coef * target_current_diff, loss_masks
            )
        else:
            with torch.no_grad():
                branched_per_action_ent = ModelUtils.break_into_branches(
                    log_probs * log_probs.exp(), self.act_size
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss = -1 * ModelUtils.masked_mean(
                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
            )

        return entropy_loss
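For reference, the continuous branch above reduces to L(alpha) = -masked_mean(log_alpha * (sum of log pi(a|s) + H_target)), with the difference term held constant so that only the entropy coefficient receives gradients. A minimal sketch of that computation, using made-up shapes and an inline masked mean in place of ModelUtils.masked_mean:

import torch

# Illustrative shapes only, not ML-Agents code.
batch, act_dim = 4, 2
log_probs = torch.randn(batch, act_dim)           # per-dimension log pi(a|s)
target_entropy = -float(act_dim)                  # common SAC heuristic: -|A|
log_ent_coef = torch.zeros(1, requires_grad=True)
loss_masks = torch.ones(batch)                    # 1.0 = valid timestep

with torch.no_grad():
    target_current_diff = torch.sum(log_probs + target_entropy, dim=1)

per_step = log_ent_coef * target_current_diff
entropy_loss = -1 * (per_step * loss_masks).sum() / loss_masks.sum()  # masked mean
entropy_loss.backward()  # only log_ent_coef gets a gradient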
Example No. 2
    def forward(self, inp: torch.Tensor,
                key_masks: List[torch.Tensor]) -> torch.Tensor:
        # Gather the maximum number of entities information
        mask = torch.cat(key_masks, dim=1)

        inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Only use max num if provided
        if self.max_num_ent is not None:
            num_ent = self.max_num_ent
        else:
            num_ent = inp.shape[1]
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max "
                    "number of elements."
                )

        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
        # Residual
        output = self.fc_out(output) + inp
        output = self.residual_norm(output)
        # Average Pooling
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1),
                              dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        return output
Example No. 3
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Takes a List of Tensors and returns a List of mask Tensors with 1 where the input
    was all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
    layer to mask the padding observations.
    """
    with torch.no_grad():

        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # We ignore a TracerWarning from PyTorch that warns that doing
                # shape[n].item() will cause the trace to be incorrect (the trace might
                # not generalize to other inputs)
                # We ignore this warning because we know the model will always be
                # run with inputs of the same shape
                warnings.simplefilter("ignore")
                # When exporting to ONNX, we want to transpose the entities. This is
                # because ONNX only supports input in NCHW (channel-first) format.
                # Barracuda also expects to get data in NCHW.
                entities = [
                    torch.transpose(obs, 2, 1).reshape(-1, obs.shape[1].item(),
                                                       obs.shape[2].item())
                    for obs in entities
                ]

        # Generate the masking tensors for each entities tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities
        ]
    return key_masks
Example No. 4
 def _mask_branch(self, logits: torch.Tensor,
                  mask: torch.Tensor) -> torch.Tensor:
     raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
     normalized_probs = raw_probs / torch.sum(raw_probs,
                                              dim=-1).unsqueeze(-1)
     normalized_logits = torch.log(normalized_probs + EPSILON)
     return normalized_logits
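The branch masking above zeroes out the probabilities of disallowed actions and renormalizes before taking the log. A toy illustration with made-up logits and a mask (not part of the original module):

import torch

EPSILON = 1e-7  # stand-in for the module-level constant used above

# Toy values: a 3-action branch where action 2 is masked out.
logits = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = allowed, 0 = masked

raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
normalized_logits = torch.log(normalized_probs + EPSILON)

print(normalized_probs)   # probabilities renormalized over the allowed actions
print(normalized_logits)  # the masked action's log-prob collapses to log(EPSILON)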
Example No. 5
 def _behavioral_cloning_loss(
     self,
     selected_actions: AgentAction,
     log_probs: ActionLogProbs,
     expert_actions: AgentAction,
 ) -> torch.Tensor:
     bc_loss = 0
     if self.policy.behavior_spec.action_spec.continuous_size > 0:
         bc_loss += torch.nn.functional.mse_loss(
             selected_actions.continuous_tensor, expert_actions.continuous_tensor
         )
     if self.policy.behavior_spec.action_spec.discrete_size > 0:
         one_hot_expert_actions = ModelUtils.actions_to_onehot(
             expert_actions.discrete_tensor,
             self.policy.behavior_spec.action_spec.discrete_branches,
         )
         log_prob_branches = ModelUtils.break_into_branches(
             log_probs.all_discrete_tensor,
             self.policy.behavior_spec.action_spec.discrete_branches,
         )
         bc_loss += torch.mean(
             torch.stack(
                 [
                     torch.sum(
                         -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                         * expert_actions_branch,
                         dim=1,
                     )
                     for log_prob_branch, expert_actions_branch in zip(
                         log_prob_branches, one_hot_expert_actions
                     )
                 ]
             )
         )
     return bc_loss
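The discrete term above is the standard cross-entropy between the expert's one-hot action and the policy's log-softmax, computed per branch. A small sanity check with stand-in values for a single branch:

import torch

# Stand-in values for one branch: the expert chose action 1.
log_prob_branch = torch.tensor([[0.2, 0.5, 0.3]])
expert_onehot = torch.tensor([[0.0, 1.0, 0.0]])

ce = torch.sum(
    -torch.nn.functional.log_softmax(log_prob_branch, dim=1) * expert_onehot, dim=1
)
# Same value as torch.nn.functional.cross_entropy(log_prob_branch, torch.tensor([1]))
print(ce)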
Example No. 6
def test_predict_with_condition(num_cond_layers):
    np.random.seed(1336)
    torch.manual_seed(1336)
    input_size, goal_size, h, num_normal_layers = 10, 1, 16, 1

    conditional_enc = ConditionalEncoder(
        input_size, goal_size, h, num_normal_layers + num_cond_layers, num_cond_layers
    )
    l_layer = linear_layer(h, 1)

    optimizer = torch.optim.Adam(
        list(conditional_enc.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    for _ in range(300):
        input_tensor = torch.rand((batch_size, input_size))
        goal_tensor = (torch.rand((batch_size, goal_size)) > 0.5).float()
        # If the goal is 1: do the sum of the inputs, else, return 0
        target = torch.sum(input_tensor, dim=1, keepdim=True) * goal_tensor
        target = target.detach()
        prediction = l_layer(conditional_enc(input_tensor, goal_tensor))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2

        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
Example No. 7
 def sac_policy_loss(
     self,
     log_probs: torch.Tensor,
     q1p_outs: Dict[str, torch.Tensor],
     loss_masks: torch.Tensor,
     discrete: bool,
 ) -> torch.Tensor:
     _ent_coef = torch.exp(self._log_ent_coef)
     mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
     if not discrete:
         mean_q1 = mean_q1.unsqueeze(1)
         batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
         policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
     else:
         action_probs = log_probs.exp()
         branched_per_action_ent = ModelUtils.break_into_branches(
             log_probs * action_probs, self.act_size
         )
         branched_q_term = ModelUtils.break_into_branches(
             mean_q1 * action_probs, self.act_size
         )
         branched_policy_loss = torch.stack(
             [
                 torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True)
                 for i, (_lp, _qt) in enumerate(
                     zip(branched_per_action_ent, branched_q_term)
                 )
             ],
             dim=1,
         )
         batch_policy_loss = torch.squeeze(branched_policy_loss)
         policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
     return policy_loss
Example No. 8
    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        _cont_ent_coef = _cont_ent_coef.exp()
        _disc_ent_coef = _disc_ent_coef.exp()

        mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
        batch_policy_loss = 0
        if self._action_spec.discrete_size > 0:
            disc_log_probs = log_probs.all_discrete_tensor
            disc_action_probs = disc_log_probs.exp()
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_action_probs,
                self._action_spec.discrete_branches)
            branched_q_term = ModelUtils.break_into_branches(
                mean_q1 * disc_action_probs,
                self._action_spec.discrete_branches)
            branched_policy_loss = torch.stack(
                [
                    torch.sum(
                        _disc_ent_coef[i] * _lp - _qt, dim=1, keepdim=False)
                    for i, (_lp, _qt) in enumerate(
                        zip(branched_per_action_ent, branched_q_term))
                ],
                dim=1,
            )
            batch_policy_loss += torch.sum(branched_policy_loss, dim=1)
            all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
        else:
            all_mean_q1 = mean_q1
        if self._action_spec.continuous_size > 0:
            cont_log_probs = log_probs.continuous_tensor
            batch_policy_loss += torch.mean(_cont_ent_coef * cont_log_probs -
                                            all_mean_q1.unsqueeze(1),
                                            dim=1)
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

        return policy_loss
Example No. 9
 def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
     with torch.no_grad():
         target = self._random_network(mini_batch)
     prediction = self._training_network(mini_batch)
     loss = torch.mean(torch.sum((prediction - target)**2, dim=1))
     self.optimizer.zero_grad()
     loss.backward()
     self.optimizer.step()
     return {"Losses/RND Loss": loss.detach().cpu().numpy()}
Example No. 10
    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        entropy_loss = 0
        if self._action_spec.discrete_size > 0:
            with torch.no_grad():
                # Break the discrete log probs into their separate branches
                disc_log_probs = log_probs.all_discrete_tensor
                branched_per_action_ent = ModelUtils.break_into_branches(
                    disc_log_probs * disc_log_probs.exp(),
                    self._action_spec.discrete_branches,
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy.discrete
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss += -1 * ModelUtils.masked_mean(
                torch.mean(_disc_ent_coef * target_current_diff, axis=1), loss_masks
            )
        if self._action_spec.continuous_size > 0:
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # We update all the _cont_ent_coef as one block
            entropy_loss += -1 * ModelUtils.masked_mean(
                _cont_ent_coef * target_current_diff, loss_masks
            )

        return entropy_loss
Example No. 11
 def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
     """
     Calculates the curiosity reward for the mini_batch. Corresponds to the error
     between the predicted and actual next state.
     """
     predicted_next_state = self.predict_next_state(mini_batch)
     target = self.get_next_state(mini_batch)
     sq_difference = 0.5 * (target - predicted_next_state)**2
     sq_difference = torch.sum(sq_difference, dim=1)
     return sq_difference
Example No. 12
 def forward(
     self,
     x_self: torch.Tensor,
     entities: List[torch.Tensor],
     key_masks: List[torch.Tensor],
 ) -> torch.Tensor:
     # Gather the maximum number of entities information
     if self.entities_num_max_elements is None:
         self.entities_num_max_elements = []
         for ent in entities:
             self.entities_num_max_elements.append(ent.shape[1])
     # Concatenate all observations with self
     self_and_ent: List[torch.Tensor] = []
     for num_entities, ent in zip(self.entities_num_max_elements, entities):
          expanded_self = x_self.reshape(-1, 1, self.self_size)
          # Equivalent to expanded_self.repeat(1, num_entities, 1)
          expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     # Generate the tensor that will serve as query, key and value to self attention
     qkv = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     mask = torch.cat(key_masks, dim=1)
     # Feed to self attention
     max_num_ent = sum(self.entities_num_max_elements)
     output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent,
                                max_num_ent)
     # Residual
     output = self.residual_layer(output) + qkv
     # Average Pooling
     numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1),
                           dim=1)
      denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
     output = numerator / denominator
     # Residual between x_self and the output of the module
     output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
     return output
Example No. 13
 def forward(self, inp: torch.Tensor,
             key_masks: List[torch.Tensor]) -> torch.Tensor:
     # Gather the maximum number of entities information
     mask = torch.cat(key_masks, dim=1)
     # Feed to self attention
     query = self.fc_q(inp)  # (b, n_q, emb)
     key = self.fc_k(inp)  # (b, n_k, emb)
     value = self.fc_v(inp)  # (b, n_k, emb)
     output, _ = self.attention(query, key, value, self.max_num_ent,
                                self.max_num_ent, mask)
     # Residual
     output = self.fc_out(output) + inp
     # Average Pooling
     numerator = torch.sum(output *
                           (1 - mask).reshape(-1, self.max_num_ent, 1),
                           dim=1)
     denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
     output = numerator / denominator
      return output
Example No. 14
 def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
     """
     Computes the inverse loss for a mini_batch. Corresponds to the error on the
     action prediction (given the current and next state).
     """
     predicted_action = self.predict_action(mini_batch)
     actions = AgentAction.from_dict(mini_batch)
     _inverse_loss = 0
     if self._action_spec.continuous_size > 0:
         sq_difference = (
             actions.continuous_tensor - predicted_action.continuous
         ) ** 2
         sq_difference = torch.sum(sq_difference, dim=1)
         _inverse_loss += torch.mean(
             ModelUtils.dynamic_partition(
                 sq_difference,
                 ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                 2,
             )[1]
         )
     if self._action_spec.discrete_size > 0:
         true_action = torch.cat(
             ModelUtils.actions_to_onehot(
                 actions.discrete_tensor, self._action_spec.discrete_branches
             ),
             dim=1,
         )
         cross_entropy = torch.sum(
             -torch.log(predicted_action.discrete + self.EPSILON) * true_action,
             dim=1,
         )
         _inverse_loss += torch.mean(
             ModelUtils.dynamic_partition(
                 cross_entropy,
                 ModelUtils.list_to_tensor(
                     mini_batch["masks"], dtype=torch.float
                 ),  # use masks not action_masks
                 2,
             )[1]
         )
     return _inverse_loss
Example No. 15
 def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                                expert_batch: AgentBuffer) -> torch.Tensor:
     """
      Computes the gradient penalty from https://arxiv.org/pdf/1704.00028, which adds
      stability, especially for off-policy algorithms. Gradients are taken w.r.t. a
      randomly interpolated input.
     """
     policy_inputs = self.get_state_inputs(policy_batch)
     expert_inputs = self.get_state_inputs(expert_batch)
     interp_inputs = []
     for policy_input, expert_input in zip(policy_inputs, expert_inputs):
         obs_epsilon = torch.rand(policy_input.shape)
         interp_input = obs_epsilon * policy_input + (
             1 - obs_epsilon) * expert_input
         interp_input.requires_grad = True  # For gradient calculation
         interp_inputs.append(interp_input)
     if self._settings.use_actions:
         policy_action = self.get_action_input(policy_batch)
         expert_action = self.get_action_input(expert_batch)
         action_epsilon = torch.rand(policy_action.shape)
         policy_dones = torch.as_tensor(policy_batch[BufferKey.DONE],
                                        dtype=torch.float).unsqueeze(1)
         expert_dones = torch.as_tensor(expert_batch[BufferKey.DONE],
                                        dtype=torch.float).unsqueeze(1)
         dones_epsilon = torch.rand(policy_dones.shape)
         action_inputs = torch.cat(
             [
                 action_epsilon * policy_action +
                 (1 - action_epsilon) * expert_action,
                 dones_epsilon * policy_dones +
                 (1 - dones_epsilon) * expert_dones,
             ],
             dim=1,
         )
         action_inputs.requires_grad = True
         hidden, _ = self.encoder(interp_inputs, action_inputs)
         encoder_input = tuple(interp_inputs + [action_inputs])
     else:
         hidden, _ = self.encoder(interp_inputs)
         encoder_input = tuple(interp_inputs)
     if self._settings.use_vail:
         use_vail_noise = True
         z_mu = self._z_mu_layer(hidden)
         hidden = z_mu + torch.randn_like(
             z_mu) * self._z_sigma * use_vail_noise
     estimate = self._estimator(hidden).squeeze(1).sum()
     gradient = torch.autograd.grad(estimate,
                                    encoder_input,
                                    create_graph=True)[0]
     # Norm's gradient could be NaN at 0. Use our own safe_norm
     safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
     gradient_mag = torch.mean((safe_norm - 1)**2)
     return gradient_mag
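The interpolation / autograd.grad / safe-norm pattern above is easier to see on a toy discriminator. A self-contained sketch with a stand-in linear estimator and random inputs (not the ML-Agents encoder):

import torch

EPSILON = 1e-7  # stand-in for the class constant used above
estimator = torch.nn.Linear(4, 1)  # toy estimator in place of the discriminator

policy_obs = torch.rand(8, 4)
expert_obs = torch.rand(8, 4)
eps = torch.rand(8, 1)
interp = eps * policy_obs + (1 - eps) * expert_obs
interp.requires_grad = True

estimate = estimator(interp).squeeze(1).sum()
gradient = torch.autograd.grad(estimate, interp, create_graph=True)[0]
safe_norm = (torch.sum(gradient ** 2, dim=1) + EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)  # penalize gradient norms away from 1
gradient_mag.backward()  # create_graph=True lets this flow back into the estimator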
Example No. 16
def get_zero_entities_mask(
        observations: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Takes a List of Tensors and returns a List of mask Tensors with 1 where the input
    was all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
    layer to mask the padding observations.
    """
    with torch.no_grad():
        # Generate the masking tensors for each entities tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent**2, axis=2) < 0.01).float() for ent in observations
        ]
    return key_masks
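A small usage sketch of get_zero_entities_mask with hand-made padded entities (made-up values; assumes the function above is in scope):

import torch

# Made-up data: a batch of 2 observations with up to 3 entities of size 2 each,
# where padded entity slots are all-zero vectors.
entities = torch.tensor(
    [
        [[0.5, 0.1], [0.0, 0.0], [0.0, 0.0]],  # 1 real entity, 2 padded
        [[0.2, 0.3], [0.4, 0.4], [0.0, 0.0]],  # 2 real entities, 1 padded
    ]
)
masks = get_zero_entities_mask([entities])
print(masks[0])  # tensor([[0., 1., 1.], [0., 0., 1.]])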
Example No. 17
 def compute_loss(
     self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
  ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
     """
     Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
     """
     total_loss = torch.zeros(1)
     stats_dict: Dict[str, np.ndarray] = {}
     policy_estimate, policy_mu = self.compute_estimate(
         policy_batch, use_vail_noise=True
     )
     expert_estimate, expert_mu = self.compute_estimate(
         expert_batch, use_vail_noise=True
     )
     stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()
     stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()
     discriminator_loss = -(
         torch.log(expert_estimate + self.EPSILON)
         + torch.log(1.0 - policy_estimate + self.EPSILON)
     ).mean()
     stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()
     total_loss += discriminator_loss
     if self._settings.use_vail:
         # KL divergence loss (encourage latent representation to be normal)
         kl_loss = torch.mean(
             -torch.sum(
                 1
                 + (self._z_sigma ** 2).log()
                 - 0.5 * expert_mu ** 2
                 - 0.5 * policy_mu ** 2
                 - (self._z_sigma ** 2),
                 dim=1,
             )
         )
         vail_loss = self._beta * (kl_loss - self.mutual_information)
         with torch.no_grad():
             self._beta.data = torch.max(
                 self._beta + self.alpha * (kl_loss - self.mutual_information),
                 torch.tensor(0.0),
             )
         total_loss += vail_loss
         stats_dict["Policy/GAIL Beta"] = self._beta.item()
         stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()
     if self.gradient_penalty_weight > 0.0:
         gradient_magnitude_loss = (
             self.gradient_penalty_weight
             * self.compute_gradient_magnitude(policy_batch, expert_batch)
         )
         stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
         total_loss += gradient_magnitude_loss
     return total_loss, stats_dict
Example No. 18
 def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
     """
     Computes the inverse loss for a mini_batch. Corresponds to the error on the
     action prediction (given the current and next state).
     """
     predicted_action = self.predict_action(mini_batch)
     if self._policy_specs.is_action_continuous():
         sq_difference = (ModelUtils.list_to_tensor(mini_batch["actions"],
                                                    dtype=torch.float) -
                          predicted_action)**2
         sq_difference = torch.sum(sq_difference, dim=1)
         return torch.mean(
             ModelUtils.dynamic_partition(
                 sq_difference,
                 ModelUtils.list_to_tensor(mini_batch["masks"],
                                           dtype=torch.float),
                 2,
             )[1])
     else:
         true_action = torch.cat(
             ModelUtils.actions_to_onehot(
                 ModelUtils.list_to_tensor(mini_batch["actions"],
                                           dtype=torch.long),
                 self._policy_specs.discrete_action_branches,
             ),
             dim=1,
         )
         cross_entropy = torch.sum(
             -torch.log(predicted_action + self.EPSILON) * true_action,
             dim=1)
         return torch.mean(
             ModelUtils.dynamic_partition(
                 cross_entropy,
                 ModelUtils.list_to_tensor(
                     mini_batch["masks"],
                     dtype=torch.float),  # use masks not action_masks
                 2,
             )[1])
Example No. 19
def test_multi_head_attention_masking():
    epsilon = 0.0001
    n_h, emb_size = 4, 12
    n_k, n_q, b = 13, 14, 15
    mha = MultiHeadAttention(emb_size, n_h)
    # create a key input with some keys all 0
    query = torch.ones((b, n_q, emb_size))
    key = torch.ones((b, n_k, emb_size))
    value = torch.ones((b, n_k, emb_size))

    mask = torch.zeros((b, n_k))
    for i in range(n_k):
        if i % 3 == 0:
            key[:, i, :] = 0
            mask[:, i] = 1

    _, attention = mha.forward(query, key, value, n_q, n_k, mask)

    for i in range(n_k):
        if i % 3 == 0:
            assert torch.sum(attention[:, :, :, i]**2) < epsilon
        else:
            assert torch.sum(attention[:, :, :, i]**2) > epsilon
Example No. 20
    def sample_actions(
        self,
        vec_obs: List[torch.Tensor],
        vis_obs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """
        :param vec_obs: List of vector observations.
        :param vis_obs: List of visual observations.
        :param masks: Loss masks for RNN, else None.
        :param memories: Input memories when using RNN, else None.
        :param seq_len: Sequence length when using RNN.
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
        :return: Tuple of actions, actions clipped to -1, 1, log probabilities (dependent on all_log_probs),
            entropies, and output memories, all as Torch Tensors.
        """
        if memories is None:
            dists, memories = self.actor_critic.get_dists(
                vec_obs, vis_obs, masks, memories, seq_len)
        else:
            # If we're using an LSTM, we need to evaluate the value estimates as well to get the critic memories
            dists, _, memories = self.actor_critic.get_dist_and_value(
                vec_obs, vis_obs, masks, memories, seq_len)
        action_list = self.actor_critic.sample_action(dists)
        log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
            action_list, dists)
        actions = torch.stack(action_list, dim=-1)
        if self.use_continuous_act:
            actions = actions[:, :, 0]
        else:
            actions = actions[:, 0, :]
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)

        if self._clip_action and self.use_continuous_act:
            clipped_action = torch.clamp(actions, -3, 3) / 3
        else:
            clipped_action = actions
        return (
            actions,
            clipped_action,
            all_logs if all_log_probs else log_probs,
            entropy_sum,
            memories,
        )
Example No. 21
 def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor,
              actions: AgentAction) -> Tuple[ActionLogProbs, torch.Tensor]:
     """
     Given actions and encoding from the network body, gets the distributions and
      computes the log probabilities and entropies.
      :param inputs: The encoding from the network body
      :param masks: Action masks for discrete actions
      :param actions: The AgentAction
     :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
     """
     dists = self._get_dists(inputs, masks)
     log_probs, entropies = self._get_probs_and_entropy(actions, dists)
     # Use the sum of entropy across actions, not the mean
     entropy_sum = torch.sum(entropies, dim=1)
     return log_probs, entropy_sum
Example No. 22
    def _condense_q_streams(
            self, q_output: Dict[str, torch.Tensor],
            discrete_actions: torch.Tensor) -> Dict[str, torch.Tensor]:
        condensed_q_output = {}
        onehot_actions = ModelUtils.actions_to_onehot(discrete_actions,
                                                      self.act_size)
        for key, item in q_output.items():
            branched_q = ModelUtils.break_into_branches(item, self.act_size)
            only_action_qs = torch.stack([
                torch.sum(_act * _q, dim=1, keepdim=True)
                for _act, _q in zip(onehot_actions, branched_q)
            ])

            condensed_q_output[key] = torch.mean(only_action_qs, dim=0)
        return condensed_q_output
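The per-branch selection above is a one-hot dot product that keeps only the Q-value of the action actually taken in that branch. A toy illustration with made-up Q-values for a single branch:

import torch

# Made-up Q-values for one branch of size 3, batch of 2; actions 1 and 2 were taken.
branched_q = torch.tensor([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0]])
onehot_actions = torch.tensor([[0.0, 1.0, 0.0],
                               [0.0, 0.0, 1.0]])

only_action_qs = torch.sum(onehot_actions * branched_q, dim=1, keepdim=True)
print(only_action_qs)  # tensor([[2.], [6.]])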
Example No. 23
 def evaluate_actions(
     self,
     vec_obs: torch.Tensor,
     vis_obs: torch.Tensor,
     actions: torch.Tensor,
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
     seq_len: int = 1,
 ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
     dists, value_heads, _ = self.actor_critic.get_dist_and_value(
         vec_obs, vis_obs, masks, memories, seq_len)
     action_list = [actions[..., i] for i in range(actions.shape[-1])]
     log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(
         action_list, dists)
     # Use the sum of entropy across actions, not the mean
     entropy_sum = torch.sum(entropies, dim=1)
     return log_probs, entropy_sum, value_heads
Example No. 24
 def forward(
     self, inputs: torch.Tensor, masks: torch.Tensor
 ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
     """
     The forward method of this module. Outputs the action, log probs,
     and entropies given the encoding from the network body.
      :param inputs: The encoding from the network body
      :param masks: Action masks for discrete actions
     :return: Given the input, an AgentAction of the actions generated by the policy and the corresponding
     ActionLogProbs and entropies.
     """
     dists = self._get_dists(inputs, masks)
     actions = self._sample_action(dists)
     log_probs, entropies = self._get_probs_and_entropy(actions, dists)
     # Use the sum of entropy across actions, not the mean
     entropy_sum = torch.sum(entropies, dim=1)
     return (actions, log_probs, entropy_sum)
Example No. 25
def test_simple_transformer_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(list(transformer.parameters()) +
                                 list(l_layer.parameters()),
                                 lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
Example No. 26
 def _behavioral_cloning_loss(self, selected_actions, log_probs,
                              expert_actions):
     if self.policy.use_continuous_act:
         bc_loss = torch.nn.functional.mse_loss(selected_actions,
                                                expert_actions)
     else:
         log_prob_branches = ModelUtils.break_into_branches(
             log_probs, self.policy.act_size)
         bc_loss = torch.mean(
             torch.stack([
                 torch.sum(
                     -torch.nn.functional.log_softmax(
                         log_prob_branch, dim=1) * expert_actions_branch,
                     dim=1,
                 ) for log_prob_branch, expert_actions_branch in zip(
                     log_prob_branches, expert_actions)
             ]))
     return bc_loss
Example No. 27
    def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                                   expert_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the gradient penalty from https://arxiv.org/pdf/1704.00028, which adds
        stability, especially for off-policy algorithms. Gradients are taken w.r.t. a
        randomly interpolated input.
        """
        policy_obs = self.get_state_encoding(policy_batch)
        expert_obs = self.get_state_encoding(expert_batch)
        obs_epsilon = torch.rand(policy_obs.shape)
        encoder_input = obs_epsilon * policy_obs + (1 -
                                                    obs_epsilon) * expert_obs
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
            action_epsilon = torch.rand(policy_action.shape)
            policy_dones = torch.as_tensor(policy_batch["done"],
                                           dtype=torch.float).unsqueeze(1)
            expert_dones = torch.as_tensor(expert_batch["done"],
                                           dtype=torch.float).unsqueeze(1)
            dones_epsilon = torch.rand(policy_dones.shape)
            encoder_input = torch.cat(
                [
                    encoder_input,
                    action_epsilon * policy_action +
                    (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones +
                    (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
        hidden = self.encoder(encoder_input)
        if self._settings.use_vail:
            use_vail_noise = True
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        estimate = self._estimator(hidden).squeeze(1).sum()

        gradient = torch.autograd.grad(estimate,
                                       encoder_input,
                                       create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1)**2)
        return gradient_mag
Example No. 28
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
Example No. 29
def test_all_masking(mask_value):
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
Example No. 30
def test_multi_head_attention_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        query = torch.rand(
            (batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((query - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5