Example #1
def test_continuous_action_prediction(behavior_spec: BehaviorSpec,
                                      seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    curiosity_settings = CuriositySettings(32, 0.1)
    curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    for _ in range(200):
        curiosity_rp.update(buffer)
    prediction = curiosity_rp._network.predict_action(buffer)[0]
    target = torch.tensor(buffer["actions"][0])
    error = torch.mean((prediction - target)**2).item()
    assert error < 0.001
Example #2
 def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                                expert_batch: AgentBuffer) -> torch.Tensor:
     """
     Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp.
     for off-policy. Compute gradients w.r.t randomly interpolated input.
     """
     policy_inputs = self.get_state_inputs(policy_batch)
     expert_inputs = self.get_state_inputs(expert_batch)
     interp_inputs = []
     for policy_input, expert_input in zip(policy_inputs, expert_inputs):
         obs_epsilon = torch.rand(policy_input.shape)
         interp_input = obs_epsilon * policy_input + (
             1 - obs_epsilon) * expert_input
         interp_input.requires_grad = True  # For gradient calculation
         interp_inputs.append(interp_input)
     if self._settings.use_actions:
         policy_action = self.get_action_input(policy_batch)
         expert_action = self.get_action_input(expert_batch)
         action_epsilon = torch.rand(policy_action.shape)
         policy_dones = torch.as_tensor(policy_batch["done"],
                                        dtype=torch.float).unsqueeze(1)
         expert_dones = torch.as_tensor(expert_batch["done"],
                                        dtype=torch.float).unsqueeze(1)
         dones_epsilon = torch.rand(policy_dones.shape)
         action_inputs = torch.cat(
             [
                 action_epsilon * policy_action +
                 (1 - action_epsilon) * expert_action,
                 dones_epsilon * policy_dones +
                 (1 - dones_epsilon) * expert_dones,
             ],
             dim=1,
         )
         action_inputs.requires_grad = True
         hidden, _ = self.encoder(interp_inputs, action_inputs)
         encoder_input = tuple(interp_inputs + [action_inputs])
     else:
         hidden, _ = self.encoder(interp_inputs)
         encoder_input = tuple(interp_inputs)
     if self._settings.use_vail:
         use_vail_noise = True
         z_mu = self._z_mu_layer(hidden)
         hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
     estimate = self._estimator(hidden).squeeze(1).sum()
     gradient = torch.autograd.grad(estimate,
                                    encoder_input,
                                    create_graph=True)[0]
     # Norm's gradient could be NaN at 0. Use our own safe_norm
     safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
     gradient_mag = torch.mean((safe_norm - 1)**2)
     return gradient_mag
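
A minimal, self-contained sketch of the same WGAN-GP penalty on a toy discriminator may help; the names and the flat (batch, features) observation shape here are illustrative assumptions, not part of ML-Agents:

import torch

def toy_gradient_penalty(discriminator: torch.nn.Module,
                         policy_obs: torch.Tensor,
                         expert_obs: torch.Tensor,
                         eps: float = 1e-7) -> torch.Tensor:
    # Interpolate between policy and expert samples with per-element noise.
    alpha = torch.rand_like(policy_obs)
    interp = alpha * policy_obs + (1 - alpha) * expert_obs
    interp.requires_grad_(True)
    estimate = discriminator(interp).sum()
    # Gradient of the discriminator output w.r.t. the interpolated inputs.
    grad = torch.autograd.grad(estimate, interp, create_graph=True)[0]
    # Penalize deviation of the gradient norm from 1; eps keeps the sqrt differentiable at 0.
    safe_norm = (grad.pow(2).sum(dim=1) + eps).sqrt()
    return ((safe_norm - 1) ** 2).mean()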
Example #3
 def compute_loss(
     self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
 ) -> Tuple[torch.Tensor, Dict[str, float]]:
     """
     Given a policy mini-batch and an expert mini-batch, computes the discriminator
     loss and a dictionary of stats to report.
     """
     total_loss = torch.zeros(1)
     stats_dict: Dict[str, float] = {}
     policy_estimate, policy_mu = self.compute_estimate(
         policy_batch, use_vail_noise=True
     )
     expert_estimate, expert_mu = self.compute_estimate(
         expert_batch, use_vail_noise=True
     )
     stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()
     stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()
     discriminator_loss = -(
         torch.log(expert_estimate + self.EPSILON)
         + torch.log(1.0 - policy_estimate + self.EPSILON)
     ).mean()
     stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()
     total_loss += discriminator_loss
     if self._settings.use_vail:
         # KL divergence loss (encourage latent representation to be normal)
         kl_loss = torch.mean(
             -torch.sum(
                 1
                 + (self._z_sigma ** 2).log()
                 - 0.5 * expert_mu ** 2
                 - 0.5 * policy_mu ** 2
                 - (self._z_sigma ** 2),
                 dim=1,
             )
         )
         vail_loss = self._beta * (kl_loss - self.mutual_information)
         with torch.no_grad():
             self._beta.data = torch.max(
                 self._beta + self.alpha * (kl_loss - self.mutual_information),
                 torch.tensor(0.0),
             )
         total_loss += vail_loss
         stats_dict["Policy/GAIL Beta"] = self._beta.item()
         stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()
     if self.gradient_penalty_weight > 0.0:
         gradient_magnitude_loss = (
             self.gradient_penalty_weight
             * self.compute_gradient_magnitude(policy_batch, expert_batch)
         )
         stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
         total_loss += gradient_magnitude_loss
     return total_loss, stats_dict
 def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
     """
     Computes the inverse loss for a mini_batch. Corresponds to the error on the
     action prediction (given the current and next state).
     """
     predicted_action = self.predict_action(mini_batch)
     if self._policy_specs.is_action_continuous():
         sq_difference = (ModelUtils.list_to_tensor(mini_batch["actions"],
                                                    dtype=torch.float) -
                          predicted_action)**2
         sq_difference = torch.sum(sq_difference, dim=1)
         return torch.mean(
             ModelUtils.dynamic_partition(
                 sq_difference,
                 ModelUtils.list_to_tensor(mini_batch["masks"],
                                           dtype=torch.float),
                 2,
             )[1])
     else:
         true_action = torch.cat(
             ModelUtils.actions_to_onehot(
                 ModelUtils.list_to_tensor(mini_batch["actions"],
                                           dtype=torch.long),
                 self._policy_specs.discrete_action_branches,
             ),
             dim=1,
         )
         cross_entropy = torch.sum(
             -torch.log(predicted_action + self.EPSILON) * true_action,
             dim=1)
         return torch.mean(
             ModelUtils.dynamic_partition(
                 cross_entropy,
                 ModelUtils.list_to_tensor(
                     mini_batch["masks"],
                     dtype=torch.float),  # use masks not action_masks
                 2,
             )[1])
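
The discriminator term in compute_loss above is the standard GAIL binary cross-entropy, -log D(expert) - log(1 - D(policy)); a stripped-down sketch with dummy sigmoid outputs (illustrative only):

import torch

EPSILON = 1e-7  # small constant, playing the role of self.EPSILON above

def gail_discriminator_loss(policy_estimate: torch.Tensor,
                            expert_estimate: torch.Tensor) -> torch.Tensor:
    # Expert samples should be scored close to 1, policy samples close to 0.
    return -(torch.log(expert_estimate + EPSILON)
             + torch.log(1.0 - policy_estimate + EPSILON)).mean()

loss = gail_discriminator_loss(torch.rand(8), torch.rand(8))  # dummy estimates in (0, 1)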
Example #5
    def _condense_q_streams(
            self, q_output: Dict[str, torch.Tensor],
            discrete_actions: torch.Tensor) -> Dict[str, torch.Tensor]:
        condensed_q_output = {}
        onehot_actions = ModelUtils.actions_to_onehot(discrete_actions,
                                                      self.act_size)
        for key, item in q_output.items():
            branched_q = ModelUtils.break_into_branches(item, self.act_size)
            only_action_qs = torch.stack([
                torch.sum(_act * _q, dim=1, keepdim=True)
                for _act, _q in zip(onehot_actions, branched_q)
            ])

            condensed_q_output[key] = torch.mean(only_action_qs, dim=0)
        return condensed_q_output
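
The one-hot multiplication in _condense_q_streams simply picks out, per branch, the Q-value of the action that was actually taken; a small illustrative check:

import torch

q_values = torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])   # batch of 2, one branch with 3 actions
actions = torch.tensor([2, 0])               # taken action index per batch entry
onehot = torch.nn.functional.one_hot(actions, num_classes=3).float()
taken_q = torch.sum(onehot * q_values, dim=1, keepdim=True)
# taken_q is [[3.0], [4.0]]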
Example #6
    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        _cont_ent_coef = _cont_ent_coef.exp()
        _disc_ent_coef = _disc_ent_coef.exp()

        mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
        batch_policy_loss = 0
        if self._action_spec.discrete_size > 0:
            disc_log_probs = log_probs.all_discrete_tensor
            disc_action_probs = disc_log_probs.exp()
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_action_probs,
                self._action_spec.discrete_branches)
            branched_q_term = ModelUtils.break_into_branches(
                mean_q1 * disc_action_probs,
                self._action_spec.discrete_branches)
            branched_policy_loss = torch.stack(
                [
                    torch.sum(
                        _disc_ent_coef[i] * _lp - _qt, dim=1, keepdim=False)
                    for i, (_lp, _qt) in enumerate(
                        zip(branched_per_action_ent, branched_q_term))
                ],
                dim=1,
            )
            batch_policy_loss += torch.sum(branched_policy_loss, dim=1)
            all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
        else:
            all_mean_q1 = mean_q1
        if self._action_spec.continuous_size > 0:
            cont_log_probs = log_probs.continuous_tensor
            batch_policy_loss += (
                _cont_ent_coef * torch.sum(cont_log_probs, dim=1) -
                all_mean_q1)
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

        return policy_loss
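
For a purely continuous action space, the loss above reduces to the familiar SAC policy objective E[alpha * log pi(a|s) - Q(s, a)]; a compact, illustrative form (without the loss mask):

import torch

def sac_continuous_policy_loss(cont_log_probs: torch.Tensor,
                               q_values: torch.Tensor,
                               ent_coef: torch.Tensor) -> torch.Tensor:
    # cont_log_probs: (batch, action_dim), q_values: (batch,), ent_coef: scalar alpha
    per_sample = ent_coef * cont_log_probs.sum(dim=1) - q_values
    return per_sample.mean()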
Example #7
 def _behavioral_cloning_loss(self, selected_actions, log_probs,
                              expert_actions):
     if self.policy.use_continuous_act:
         bc_loss = torch.nn.functional.mse_loss(selected_actions,
                                                expert_actions)
     else:
         log_prob_branches = ModelUtils.break_into_branches(
             log_probs, self.policy.act_size)
         bc_loss = torch.mean(
             torch.stack([
                 torch.sum(
                     -torch.nn.functional.log_softmax(
                         log_prob_branch, dim=1) * expert_actions_branch,
                     dim=1,
                 ) for log_prob_branch, expert_actions_branch in zip(
                     log_prob_branches, expert_actions)
             ]))
     return bc_loss
    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        entropy_loss = 0
        if self._action_spec.discrete_size > 0:
            with torch.no_grad():
                # Break discrete log-probs into one term per action branch
                disc_log_probs = log_probs.all_discrete_tensor
                branched_per_action_ent = ModelUtils.break_into_branches(
                    disc_log_probs * disc_log_probs.exp(),
                    self._action_spec.discrete_branches,
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy.discrete
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss += -1 * ModelUtils.masked_mean(
                torch.mean(_disc_ent_coef * target_current_diff, axis=1), loss_masks
            )
        if self._action_spec.continuous_size > 0:
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # We update all the _cont_ent_coef as one block
            entropy_loss += -1 * ModelUtils.masked_mean(
                _cont_ent_coef * target_current_diff, loss_masks
            )

        return entropy_loss
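
The entropy loss above is what adapts the (log) entropy coefficient toward a target entropy; a stripped-down continuous-only sketch, where target_entropy is assumed to be the total target rather than a per-dimension value:

import torch

log_alpha = torch.zeros(1, requires_grad=True)  # learnable log entropy coefficient
target_entropy = -2.0                           # e.g. minus the action dimension

def entropy_coef_loss(cont_log_probs: torch.Tensor) -> torch.Tensor:
    # Log-probs are treated as constants; only log_alpha receives gradient, so alpha
    # rises when policy entropy falls below the target and shrinks otherwise.
    with torch.no_grad():
        diff = cont_log_probs.sum(dim=1) + target_entropy
    return -(log_alpha * diff).mean()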
Example #9
def test_visual_encoder_trains(vis_class, size):
    torch.manual_seed(0)
    image_size = (size, size, 1)
    batch = 100

    inputs = torch.cat([
        torch.zeros((batch, ) + image_size),
        torch.ones((batch, ) + image_size)
    ],
                       dim=0)
    target = torch.cat([torch.zeros((batch, )), torch.ones((batch, ))], dim=0)
    enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
    optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)

    for _ in range(15):
        prediction = enc(inputs)[:, 0]
        loss = torch.mean((target - prediction)**2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    assert loss.item() < 0.05
Example #10
 def forward(self, layer_activations: torch.Tensor) -> torch.Tensor:
     mean = torch.mean(layer_activations, dim=-1, keepdim=True)
     var = torch.mean((layer_activations - mean)**2, dim=-1, keepdim=True)
     return (layer_activations - mean) / (torch.sqrt(var + 1e-5))
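
This forward pass is plain layer normalization without a learned scale or shift; a quick check against torch.nn.functional.layer_norm with the same epsilon:

import torch

x = torch.randn(4, 16)
reference = torch.nn.functional.layer_norm(x, normalized_shape=(16,), eps=1e-5)
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + 1e-5)
assert torch.allclose(reference, manual, atol=1e-6)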
Example #11
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :param update_target: Whether or not to update target value network
        :param reward_signal_batches: Minibatches to use for updating the reward signals,
            indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)

        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
                0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # The sequence length should never be < 1 for an LSTM, but guard against indexing out of range if it is.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                batch["memory"][i]
                [self.policy.m_size //
                 2:])  # only pass value part of memory to target network
            for i in range(offset, len(batch["memory"]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (torch.zeros_like(next_memories)
                      if next_memories is not None else None)

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        (
            sampled_actions,
            log_probs,
            _,
            value_estimates,
            _,
        ) = self.policy.actor_critic.get_action_stats_and_value(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        q1p_out, q2p_out = self.value_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        q1_out, q2_out = self.value_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        with torch.no_grad():
            target_values, _ = self.target_network(
                next_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss + value_loss

        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self.policy.actor_critic.critic,
                               self.target_network, self.tau)
        update_stats = {
            "Losses/Policy Loss":
            policy_loss.item(),
            "Losses/Value Loss":
            value_loss.item(),
            "Losses/Q1 Loss":
            q1_loss.item(),
            "Losses/Q2 Loss":
            q2_loss.item(),
            "Policy/Discrete Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
            "Policy/Continuous Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
            "Policy/Learning Rate":
            decay_lr,
        }

        return update_stats
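
ModelUtils.soft_update at the end of the update is Polyak averaging of the target-network parameters; a minimal equivalent sketch (illustrative, not the actual ML-Agents implementation):

import torch

def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # target <- tau * source + (1 - tau) * target, applied parameter-wise;
    # tau = 1.0 degenerates to a hard copy of the source network.
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(), target.parameters()):
            tgt_param.data.mul_(1.0 - tau).add_(tau * src_param.data)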
Example #12
    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous.exp()
            _disc_ent_coef = self._log_ent_coef.discrete.exp()
            for name in values.keys():
                if self._action_spec.discrete_size <= 0:
                    min_policy_qs[name] = torch.min(q1p_out[name],
                                                    q2p_out[name])
                else:
                    disc_action_probs = log_probs.all_discrete_tensor.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _q1p_mean = torch.mean(
                        torch.stack([
                            torch.sum(_br, dim=1, keepdim=True)
                            for _br in _branched_q1p
                        ]),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack([
                            torch.sum(_br, dim=1, keepdim=True)
                            for _br in _branched_q2p
                        ]),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if self._action_spec.discrete_size <= 0:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1)
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup),
                    loss_masks)
                value_losses.append(value_loss)
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack([
                torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                for i, _lp in enumerate(branched_per_action_ent)
            ])
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0)
                    # Add continuous entropy bonus to minimum Q
                    if self._action_spec.continuous_size > 0:
                        v_backup += torch.sum(
                            _cont_ent_coef * log_probs.continuous_tensor,
                            dim=1,
                            keepdim=True,
                        )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name],
                                                 v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf or NaN found in value loss")
        return value_loss
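
In the continuous-only branch the value target above is the usual SAC backup V(s) ≈ min(Q1, Q2) - alpha * log pi(a|s); a compact illustrative form:

import torch

def sac_value_backup(q1: torch.Tensor, q2: torch.Tensor,
                     cont_log_probs: torch.Tensor,
                     ent_coef: torch.Tensor) -> torch.Tensor:
    # q1, q2: (batch,), cont_log_probs: (batch, action_dim), ent_coef: scalar alpha
    with torch.no_grad():
        return torch.min(q1, q2) - ent_coef * cont_log_probs.sum(dim=1)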
Example #13
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :param update_target: Whether or not to update target value network
        :param reward_signal_batches: Minibatches to use for updating the reward signals,
            indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.rewards_key(name)])

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        memories_list = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i]) for i in
            range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        # The sequence length should never be < 1 for an LSTM, but guard against indexing out of range if it is.
        value_memories_list = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            value_memories = torch.stack(value_memories_list).unsqueeze(0)
        else:
            memories = None
            value_memories = None

        # Q and V network memories are 0'ed out, since we don't have them during inference.
        q_memories = (torch.zeros_like(value_memories)
                      if value_memories is not None else None)

        # Copy normalizers from policy
        self.q_network.q1_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self.q_network.q2_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self._critic.network_body.copy_normalization(
            self.policy.actor.network_body)
        sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        value_estimates, _ = self._critic.critic_pass(
            current_obs,
            value_memories,
            sequence_length=self.policy.sequence_length)

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        q1p_out, q2p_out = self.q_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        q1_out, q2_out = self.q_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        with torch.no_grad():
            # Since we didn't record the next value memories, evaluate one step in the critic to
            # get them.
            if value_memories is not None:
                # Get the first observation in each sequence
                just_first_obs = [
                    _obs[::self.policy.sequence_length] for _obs in current_obs
                ]
                _, next_value_memories = self._critic.critic_pass(
                    just_first_obs, value_memories, sequence_length=1)
            else:
                next_value_memories = None
            target_values, _ = self.target_network(
                next_obs,
                memories=next_value_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS],
                                          dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss
        if self.policy.shared_critic:
            policy_loss += value_loss
        else:
            total_value_loss += value_loss

        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self._critic, self.target_network, self.tau)
        update_stats = {
            "Losses/Policy Loss":
            policy_loss.item(),
            "Losses/Value Loss":
            value_loss.item(),
            "Losses/Q1 Loss":
            q1_loss.item(),
            "Losses/Q2 Loss":
            q2_loss.item(),
            "Policy/Discrete Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
            "Policy/Continuous Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
            "Policy/Learning Rate":
            decay_lr,
        }

        return update_stats
Example #14
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :param update_target: Whether or not to update target value network
        :param reward_signal_batches: Minibatches to use for updating the reward signals,
            indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])]
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        if self.policy.use_continuous_act:
            actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
        else:
            actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i])
            for i in range(0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # The sequence length should never be < 1 for an LSTM, but guard against indexing out of range if it is.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                batch["memory"][i][self.policy.m_size // 2 :]
            )  # only pass value part of memory to target network
            for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (
            torch.zeros_like(next_memories) if next_memories is not None else None
        )

        vis_obs: List[torch.Tensor] = []
        next_vis_obs: List[torch.Tensor] = []
        if self.policy.use_vis_obs:
            vis_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_processors
            ):
                vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                vis_obs.append(vis_ob)
                next_vis_ob = ModelUtils.list_to_tensor(
                    batch["next_visual_obs%d" % idx]
                )
                next_vis_obs.append(next_vis_ob)

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        (sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
            all_log_probs=not self.policy.use_continuous_act,
        )
        value_estimates, _ = self.policy.actor_critic.critic_pass(
            vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
        )
        if self.policy.use_continuous_act:
            squeezed_actions = actions.squeeze(-1)
            # Only need grad for q1, as that is used for policy.
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                sampled_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
                q2_grad=False,
            )
            q1_out, q2_out = self.value_network(
                vec_obs,
                vis_obs,
                squeezed_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
            q1_stream, q2_stream = q1_out, q2_out
        else:
            # For discrete, you don't need to backprop through the Q for the policy
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
                q1_grad=False,
                q2_grad=False,
            )
            q1_out, q2_out = self.value_network(
                vec_obs,
                vis_obs,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
            q1_stream = self._condense_q_streams(q1_out, actions)
            q2_stream = self._condense_q_streams(q2_out, actions)

        with torch.no_grad():
            target_values, _ = self.target_network(
                next_vec_obs,
                next_vis_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        use_discrete = not self.policy.use_continuous_act
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
        value_loss = self.sac_value_loss(
            log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
        )
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
        entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)

        total_value_loss = q1_loss + q2_loss + value_loss

        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.target_network, self.tau
        )
        update_stats = {
            "Losses/Policy Loss": policy_loss.item(),
            "Losses/Value Loss": value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats
Example #15
    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values.keys():
                if not discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if not discrete:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf or NaN found in value loss")
        return value_loss
Example #16
 def entropy(self):
     return torch.mean(
         0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON),
         dim=1,
         keepdim=True,
     )  # Use equivalent behavior to TF
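
For reference, the closed-form entropy of a Gaussian is 0.5 * log(2 * pi * e * sigma^2); torch.distributions.Normal returns exactly that, whereas the snippet above keeps sigma (not sigma^2) inside the log to reproduce the behavior of the older TF implementation. A quick check of the analytic formula:

import math
import torch

std = torch.tensor([0.5, 1.0, 2.0])
analytic = 0.5 * torch.log(2 * math.pi * math.e * std ** 2)
assert torch.allclose(torch.distributions.Normal(0.0, std).entropy(), analytic)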