Example #1
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        if not discrete:
            with torch.no_grad():
                target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
            entropy_loss = -1 * ModelUtils.masked_mean(
                self._log_ent_coef * target_current_diff, loss_masks
            )
        else:
            with torch.no_grad():
                branched_per_action_ent = ModelUtils.break_into_branches(
                    log_probs * log_probs.exp(), self.act_size
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss = -1 * ModelUtils.masked_mean(
                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
            )

        return entropy_loss
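
Note: a minimal, self-contained sketch of the continuous branch above, with illustrative values (masked_mean is replaced by a plain mean, which is equivalent when every mask entry is 1). The target_current_diff term is computed under torch.no_grad() so that only the entropy coefficient receives gradients.

import torch

# Illustrative stand-ins for self._log_ent_coef and self.target_entropy.
log_ent_coef = torch.tensor(0.0, requires_grad=True)      # log(alpha)
log_probs = torch.tensor([[-1.0, -2.0], [-0.5, -1.5]])     # per-dimension log pi
target_entropy = torch.tensor([-1.0, -1.0])

with torch.no_grad():
    # Held constant for the coefficient update.
    target_current_diff = torch.sum(log_probs + target_entropy, dim=1)  # tensor([-5., -4.])
entropy_loss = -1 * torch.mean(log_ent_coef * target_current_diff)
entropy_loss.backward()
# log_ent_coef.grad == -mean(target_current_diff) == 4.5, so gradient descent
# lowers alpha here because the sampled entropy is already above the target.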
Example #2
def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
    """
    Make sure two policies have the same output for the same input.
    """
    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy1.behavior_spec, num_agents=1)
    vec_vis_obs, masks = policy1._split_decision_step(decision_step)
    vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
    vis_obs = [
        torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations
    ]
    memories = torch.as_tensor(
        policy1.retrieve_memories(list(decision_step.agent_id))).unsqueeze(0)

    with torch.no_grad():
        _, log_probs1, _, _, _ = policy1.sample_actions(vec_obs,
                                                        vis_obs,
                                                        masks=masks,
                                                        memories=memories,
                                                        all_log_probs=True)
        _, log_probs2, _, _, _ = policy2.sample_actions(vec_obs,
                                                        vis_obs,
                                                        masks=masks,
                                                        memories=memories,
                                                        all_log_probs=True)

    np.testing.assert_array_equal(log_probs1, log_probs2)
Example #3
    def sac_q_loss(
        self,
        q1_out: Dict[str, torch.Tensor],
        q2_out: Dict[str, torch.Tensor],
        target_values: Dict[str, torch.Tensor],
        dones: torch.Tensor,
        rewards: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q1_losses = []
        q2_losses = []
        # Multiple q losses per stream
        for i, name in enumerate(q1_out.keys()):
            q1_stream = q1_out[name].squeeze()
            q2_stream = q2_out[name].squeeze()
            with torch.no_grad():
                q_backup = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones) *
                    self.gammas[i] * target_values[name])
            _q1_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks)
            _q2_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks)

            q1_losses.append(_q1_loss)
            q2_losses.append(_q2_loss)
        q1_loss = torch.mean(torch.stack(q1_losses))
        q2_loss = torch.mean(torch.stack(q2_losses))
        return q1_loss, q2_loss
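
Note: the q_backup computed inside the no_grad block above is the standard SAC target r + gamma * (1 - done) * V_target(s'). A minimal sketch with illustrative tensors (the use_dones_in_backup switch from the original is assumed to be 1):

import torch

rewards = torch.tensor([1.0, 0.0, 0.5])
dones = torch.tensor([0.0, 0.0, 1.0])
target_values = torch.tensor([2.0, 1.5, 3.0])
gamma = 0.99

with torch.no_grad():
    # Terminal steps (done == 1) drop the bootstrapped value.
    q_backup = rewards + (1.0 - dones) * gamma * target_values
# q_backup == tensor([2.9800, 1.4850, 0.5000])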
Example #4
    def evaluate(self, decision_requests: DecisionSteps,
                 global_agent_ids: List[str]) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param global_agent_ids:
        :param decision_requests: DecisionStep object containing inputs.
        :return: Outputs from network as defined by self.inference_dict.
        """
        vec_vis_obs, masks = self._split_decision_step(decision_requests)
        vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
        vis_obs = [
            torch.as_tensor(vis_ob)
            for vis_ob in vec_vis_obs.visual_observations
        ]
        memories = torch.as_tensor(
            self.retrieve_memories(global_agent_ids)).unsqueeze(0)

        run_out = {}
        with torch.no_grad():
            action, log_probs, entropy, memories = self.sample_actions(
                vec_obs, vis_obs, masks=masks, memories=memories)
        run_out["action"] = ModelUtils.to_numpy(action)
        run_out["pre_action"] = ModelUtils.to_numpy(action)
        # TODO: make pre_action different from action
        run_out["log_probs"] = ModelUtils.to_numpy(log_probs)
        run_out["entropy"] = ModelUtils.to_numpy(entropy)
        run_out["learning_rate"] = 0.0
        if self.use_recurrent:
            run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
        return run_out
Example #5
    def evaluate(self, decision_requests: DecisionSteps,
                 global_agent_ids: List[str]) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param global_agent_ids:
        :param decision_requests: DecisionStep object containing inputs.
        :return: Outputs from network as defined by self.inference_dict.
        """
        obs = decision_requests.obs
        masks = self._extract_masks(decision_requests)
        tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]

        memories = torch.as_tensor(
            self.retrieve_memories(global_agent_ids)).unsqueeze(0)

        run_out = {}
        with torch.no_grad():
            action, log_probs, entropy, memories = self.sample_actions(
                tensor_obs, masks=masks, memories=memories)
        action_tuple = action.to_action_tuple()
        run_out["action"] = action_tuple
        # This is the clipped action which is not saved to the buffer
        # but is exclusively sent to the environment.
        env_action_tuple = action.to_action_tuple(clip=self._clip_action)
        run_out["env_action"] = env_action_tuple
        run_out["log_probs"] = log_probs.to_log_probs_tuple()
        run_out["entropy"] = ModelUtils.to_numpy(entropy)
        run_out["learning_rate"] = 0.0
        if self.use_recurrent:
            run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
        return run_out
Example #6
 def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
     with torch.no_grad():
         estimates, _ = self._discriminator_network.compute_estimate(
             mini_batch, use_vail_noise=False)
         return ModelUtils.to_numpy(
             -torch.log(1.0 - estimates.squeeze(dim=1) *
                        (1.0 - self._discriminator_network.EPSILON)))
Example #7
def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
    """
    Make sure two policies have the same output for the same input.
    """
    policy1.actor = policy1.actor.to(default_device())
    policy2.actor = policy2.actor.to(default_device())

    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy1.behavior_spec, num_agents=1)
    np_obs = decision_step.obs
    masks = policy1._extract_masks(decision_step)
    memories = torch.as_tensor(
        policy1.retrieve_memories(list(decision_step.agent_id))).unsqueeze(0)
    tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

    with torch.no_grad():
        _, log_probs1, _, _ = policy1.sample_actions(tensor_obs,
                                                     masks=masks,
                                                     memories=memories)
        _, log_probs2, _, _ = policy2.sample_actions(tensor_obs,
                                                     masks=masks,
                                                     memories=memories)
    np.testing.assert_array_equal(
        ModelUtils.to_numpy(log_probs1.all_discrete_tensor),
        ModelUtils.to_numpy(log_probs2.all_discrete_tensor),
    )
Example #8
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Takes a List of Tensors and returns a List of mask Tensors, with 1 where the input
    was all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
    layer to mask the padding observations.
    """
    with torch.no_grad():

        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # We ignore a TracerWarning from PyTorch that warns that doing
                # shape[n].item() will cause the trace to be incorrect (the trace
                # might not generalize to other inputs). We can ignore it because
                # we know the model will always be run with inputs of the same shape.
                warnings.simplefilter("ignore")
                # When exporting to ONNX, we want to transpose the entities. This is
                # because ONNX only supports input in NCHW (channel-first) format.
                # Barracuda also expects to get data in NCHW.
                entities = [
                    torch.transpose(obs, 2, 1).reshape(-1, obs.shape[1].item(),
                                                       obs.shape[2].item())
                    for obs in entities
                ]

        # Generate the masking tensors for each entities tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities
        ]
    return key_masks
Example #9
 def forward(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     actions: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
     sequence_length: int = 1,
     q1_grad: bool = True,
     q2_grad: bool = True,
 ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
     """
     Performs a forward pass on the value network, which consists of a Q1 and Q2
     network. Optionally does not evaluate gradients for either the Q1, Q2, or both.
     :param vec_inputs: List of vector observation tensors.
      :param vis_inputs: List of visual observation tensors.
     :param actions: For a continuous Q function (has actions), tensor of actions.
         Otherwise, None.
     :param memories: Initial memories if using memory. Otherwise, None.
     :param sequence_length: Sequence length if using memory.
     :param q1_grad: Whether or not to compute gradients for the Q1 network.
     :param q2_grad: Whether or not to compute gradients for the Q2 network.
     :return: Tuple of two dictionaries, which both map {reward_signal: Q} for Q1 and Q2,
         respectively.
     """
     # ExitStack allows us to enter the torch.no_grad() context conditionally
     with ExitStack() as stack:
         if not q1_grad:
             stack.enter_context(torch.no_grad())
         q1_out, _ = self.q1_network(
             vec_inputs,
             vis_inputs,
             actions=actions,
             memories=memories,
             sequence_length=sequence_length,
         )
     with ExitStack() as stack:
         if not q2_grad:
             stack.enter_context(torch.no_grad())
         q2_out, _ = self.q2_network(
             vec_inputs,
             vis_inputs,
             actions=actions,
             memories=memories,
             sequence_length=sequence_length,
         )
     return q1_out, q2_out
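
Note: the ExitStack trick above is a general way to enter torch.no_grad() conditionally without duplicating the forward pass. A standalone sketch, assuming nothing beyond PyTorch (the module and flag names are illustrative):

from contextlib import ExitStack

import torch
from torch import nn

def forward_maybe_no_grad(module: nn.Module, x: torch.Tensor, compute_grad: bool) -> torch.Tensor:
    with ExitStack() as stack:
        if not compute_grad:
            # Only enter no_grad when gradients are not needed.
            stack.enter_context(torch.no_grad())
        return module(x)

out = forward_maybe_no_grad(nn.Linear(4, 2), torch.randn(3, 4), compute_grad=False)
# out.requires_grad is False; with compute_grad=True it would be True.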
Example #10
 def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
     with torch.no_grad():
         target = self._random_network(mini_batch)
     prediction = self._training_network(mini_batch)
     loss = torch.mean(torch.sum((prediction - target)**2, dim=1))
     self.optimizer.zero_grad()
     loss.backward()
     self.optimizer.step()
     return {"Losses/RND Loss": loss.detach().cpu().numpy()}
Example #11
    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        entropy_loss = 0
        if self._action_spec.discrete_size > 0:
            with torch.no_grad():
                # Break continuous into separate branch
                disc_log_probs = log_probs.all_discrete_tensor
                branched_per_action_ent = ModelUtils.break_into_branches(
                    disc_log_probs * disc_log_probs.exp(),
                    self._action_spec.discrete_branches,
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy.discrete
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss += -1 * ModelUtils.masked_mean(
                torch.mean(_disc_ent_coef * target_current_diff, axis=1), loss_masks
            )
        if self._action_spec.continuous_size > 0:
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # We update all the _cont_ent_coef as one block
            entropy_loss += -1 * ModelUtils.masked_mean(
                _cont_ent_coef * target_current_diff, loss_masks
            )

        return entropy_loss
Example #12
def get_zero_entities_mask(
        observations: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Takes a List of Tensors and returns a List of mask Tensors, with 1 where the input
    was all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
    layer to mask the padding observations.
    """
    with torch.no_grad():
        # Generate the masking tensors for each entities tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent**2, axis=2) < 0.01).float() for ent in observations
        ]
    return key_masks
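
Note: a small usage sketch for get_zero_entities_mask as defined above (the batch of 2, 3 entities, and feature size 4 are illustrative). Entities that are all zeros along dimension 2 are treated as padding and masked with 1.0; real entities get 0.0.

import torch

entities = torch.zeros((2, 3, 4))
entities[0, 0] = torch.ones(4)            # one real entity in the first batch element
masks = get_zero_entities_mask([entities])
# masks[0] == tensor([[0., 1., 1.],
#                     [1., 1., 1.]])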
Example #13
 def compute_loss(
     self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
 ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
     """
     Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
     """
     total_loss = torch.zeros(1)
     stats_dict: Dict[str, np.ndarray] = {}
     policy_estimate, policy_mu = self.compute_estimate(
         policy_batch, use_vail_noise=True
     )
     expert_estimate, expert_mu = self.compute_estimate(
         expert_batch, use_vail_noise=True
     )
     stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()
     stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()
     discriminator_loss = -(
         torch.log(expert_estimate + self.EPSILON)
         + torch.log(1.0 - policy_estimate + self.EPSILON)
     ).mean()
     stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()
     total_loss += discriminator_loss
     if self._settings.use_vail:
         # KL divergence loss (encourage latent representation to be normal)
         kl_loss = torch.mean(
             -torch.sum(
                 1
                 + (self._z_sigma ** 2).log()
                 - 0.5 * expert_mu ** 2
                 - 0.5 * policy_mu ** 2
                 - (self._z_sigma ** 2),
                 dim=1,
             )
         )
         vail_loss = self._beta * (kl_loss - self.mutual_information)
         with torch.no_grad():
             self._beta.data = torch.max(
                 self._beta + self.alpha * (kl_loss - self.mutual_information),
                 torch.tensor(0.0),
             )
         total_loss += vail_loss
         stats_dict["Policy/GAIL Beta"] = self._beta.item()
         stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()
     if self.gradient_penalty_weight > 0.0:
         gradient_magnitude_loss = (
             self.gradient_penalty_weight
             * self.compute_gradient_magnitude(policy_batch, expert_batch)
         )
         stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
         total_loss += gradient_magnitude_loss
     return total_loss, stats_dict
Example #14
    def update(self, vector_input: torch.Tensor) -> None:
        with torch.no_grad():
            steps_increment = vector_input.size()[0]
            total_new_steps = self.normalization_steps + steps_increment

            input_to_old_mean = vector_input - self.running_mean
            new_mean: torch.Tensor = self.running_mean + (
                input_to_old_mean / total_new_steps).sum(0)

            input_to_new_mean = vector_input - new_mean
            new_variance = self.running_variance + (input_to_new_mean *
                                                    input_to_old_mean).sum(0)
            # Update references. This is much faster than in-place data update.
            self.running_mean: torch.Tensor = new_mean
            self.running_variance: torch.Tensor = new_variance
            self.normalization_steps: torch.Tensor = total_new_steps
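
Note: a hedged sketch of how running statistics like the ones maintained above are typically consumed. Since running_variance accumulates summed squared deviations (Welford's M2), the variance is running_variance / normalization_steps; the epsilon and clipping range below are illustrative choices, not taken from the source.

import torch

def normalize(vector_input: torch.Tensor,
              running_mean: torch.Tensor,
              running_variance: torch.Tensor,
              normalization_steps: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        variance = running_variance / normalization_steps
        normalized = (vector_input - running_mean) / torch.sqrt(variance + 1e-8)
        return torch.clamp(normalized, -5.0, 5.0)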
Example #15
def test_predict_minimum_training():
    # of 5 numbers, predict index of min
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size],
                                         embedding_size, [n_k],
                                         concat_self=False)
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )

    batch_size = 200
    onehots = ModelUtils.actions_to_onehot(
        torch.range(0, n_k - 1).unsqueeze(1), [n_k])[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # Create the target: the index of the minimum
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        sliced_oh = onehots[:, :num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embeddings(inp, [inp])
        masks = EntityEmbeddings.get_masks([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
Example #16
def test_simple_transformer_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(list(transformer.parameters()) +
                                 list(l_layer.parameters()),
                                 lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
Example #17
def _compare_two_optimizers(opt1: TorchOptimizer,
                            opt2: TorchOptimizer) -> None:
    trajectory = mb.make_fake_trajectory(
        length=10,
        observation_specs=opt1.policy.behavior_spec.observation_specs,
        action_spec=opt1.policy.behavior_spec.action_spec,
        max_step_complete=True,
    )
    with torch.no_grad():
        _, opt1_val_out, _ = opt1.get_trajectory_value_estimates(
            trajectory.to_agentbuffer(), trajectory.next_obs, done=False)
        _, opt2_val_out, _ = opt2.get_trajectory_value_estimates(
            trajectory.to_agentbuffer(), trajectory.next_obs, done=False)

    for opt1_val, opt2_val in zip(opt1_val_out.values(),
                                  opt2_val_out.values()):
        np.testing.assert_array_equal(opt1_val, opt2_val)
Example #18
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
Example #19
def test_all_masking(mask_value):
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
Example #20
 def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
     """
     Performs an in-place polyak update of the target module based on the source,
     by a ratio of tau. Note that source and target modules must have the same
     parameters, where:
         target = tau * source + (1-tau) * target
     :param source: Source module whose parameters will be used.
     :param target: Target module whose parameters will be updated.
     :param tau: Percentage of source parameters to use in average. Setting tau to
         1 will copy the source parameters to the target.
     """
     with torch.no_grad():
         for source_param, target_param in zip(source.parameters(),
                                               target.parameters()):
             target_param.data.mul_(1.0 - tau)
             torch.add(
                 target_param.data,
                 source_param.data,
                 alpha=tau,
                 out=target_param.data,
             )
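
Note: the same Polyak rule, written for plain tensors as a minimal sketch (assuming only PyTorch): target <- tau * source + (1 - tau) * target, performed in place under no_grad so the update itself is not tracked by autograd.

import torch

def polyak_tensor_update(source: torch.Tensor, target: torch.Tensor, tau: float) -> None:
    with torch.no_grad():
        target.mul_(1.0 - tau)
        target.add_(source, alpha=tau)

src, tgt = torch.ones(3), torch.zeros(3)
polyak_tensor_update(src, tgt, tau=0.1)   # tgt is now 0.1 everywhere
polyak_tensor_update(src, tgt, tau=1.0)   # tau=1.0 copies the source exactly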
Example #21
def test_multi_head_attention_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        query = torch.rand(
            (batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((query - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5
Example #22
    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values.keys():
                if not discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if not discrete:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
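
Note: for continuous actions, the v_backup above is min(Q1, Q2) - alpha * log_pi, held constant via no_grad. A tensor-level sketch with illustrative values:

import torch

log_probs = torch.tensor([[-1.2], [-0.3]])
q1p = torch.tensor([1.0, 2.0])
q2p = torch.tensor([0.8, 2.5])
ent_coef = torch.tensor(0.2)               # alpha = exp(log alpha)

with torch.no_grad():
    min_q = torch.min(q1p, q2p)                                # tensor([0.8000, 2.0000])
    v_backup = min_q - torch.sum(ent_coef * log_probs, dim=1)  # tensor([1.0400, 2.0600])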
Example #23
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])]
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        if self.policy.use_continuous_act:
            actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
        else:
            actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i])
            for i in range(0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # LSTM shouldn't have a sequence length < 1, but guard against indexing out of range if it does.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                batch["memory"][i][self.policy.m_size // 2 :]
            )  # only pass value part of memory to target network
            for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (
            torch.zeros_like(next_memories) if next_memories is not None else None
        )

        vis_obs: List[torch.Tensor] = []
        next_vis_obs: List[torch.Tensor] = []
        if self.policy.use_vis_obs:
            vis_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_processors
            ):
                vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                vis_obs.append(vis_ob)
                next_vis_ob = ModelUtils.list_to_tensor(
                    batch["next_visual_obs%d" % idx]
                )
                next_vis_obs.append(next_vis_ob)

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        (sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
            all_log_probs=not self.policy.use_continuous_act,
        )
        value_estimates, _ = self.policy.actor_critic.critic_pass(
            vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
        )
        if self.policy.use_continuous_act:
            squeezed_actions = actions.squeeze(-1)
            # Only need grad for q1, as that is used for policy.
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                sampled_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
                q2_grad=False,
            )
            q1_out, q2_out = self.value_network(
                vec_obs,
                vis_obs,
                squeezed_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
            q1_stream, q2_stream = q1_out, q2_out
        else:
            # For discrete, you don't need to backprop through the Q for the policy
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
                q1_grad=False,
                q2_grad=False,
            )
            q1_out, q2_out = self.value_network(
                vec_obs,
                vis_obs,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
            q1_stream = self._condense_q_streams(q1_out, actions)
            q2_stream = self._condense_q_streams(q2_out, actions)

        with torch.no_grad():
            target_values, _ = self.target_network(
                next_vec_obs,
                next_vis_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        use_discrete = not self.policy.use_continuous_act
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
        value_loss = self.sac_value_loss(
            log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
        )
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
        entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)

        total_value_loss = q1_loss + q2_loss + value_loss

        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.target_network, self.tau
        )
        update_stats = {
            "Losses/Policy Loss": policy_loss.item(),
            "Losses/Value Loss": value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats
Example #24
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.rewards_key(name)])

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        memories_list = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i]) for i in
            range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        # LSTM shouldn't have a sequence length < 1, but guard against indexing out of range if it does.
        value_memories_list = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            value_memories = torch.stack(value_memories_list).unsqueeze(0)
        else:
            memories = None
            value_memories = None

        # Q and V network memories are 0'ed out, since we don't have them during inference.
        q_memories = (torch.zeros_like(value_memories)
                      if value_memories is not None else None)

        # Copy normalizers from policy
        self.q_network.q1_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self.q_network.q2_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self._critic.network_body.copy_normalization(
            self.policy.actor.network_body)
        sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        value_estimates, _ = self._critic.critic_pass(
            current_obs,
            value_memories,
            sequence_length=self.policy.sequence_length)

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        q1p_out, q2p_out = self.q_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        q1_out, q2_out = self.q_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        with torch.no_grad():
            # Since we didn't record the next value memories, evaluate one step in the critic to
            # get them.
            if value_memories is not None:
                # Get the first observation in each sequence
                just_first_obs = [
                    _obs[::self.policy.sequence_length] for _obs in current_obs
                ]
                _, next_value_memories = self._critic.critic_pass(
                    just_first_obs, value_memories, sequence_length=1)
            else:
                next_value_memories = None
            target_values, _ = self.target_network(
                next_obs,
                memories=next_value_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS],
                                          dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss
        if self.policy.shared_critic:
            policy_loss += value_loss
        else:
            total_value_loss += value_loss

        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self._critic, self.target_network, self.tau)
        update_stats = {
            "Losses/Policy Loss":
            policy_loss.item(),
            "Losses/Value Loss":
            value_loss.item(),
            "Losses/Q1 Loss":
            q1_loss.item(),
            "Losses/Q2 Loss":
            q2_loss.item(),
            "Policy/Discrete Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
            "Policy/Continuous Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
            "Policy/Learning Rate":
            decay_lr,
        }

        return update_stats
Example #25
 def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
     with torch.no_grad():
         target = self._random_network(mini_batch)
         prediction = self._training_network(mini_batch)
         rewards = torch.sum((prediction - target)**2, dim=1)
     return rewards.detach().cpu().numpy()
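
Note: the RND intrinsic reward above is simply the squared error between a fixed random network and a trained predictor, so poorly predicted (novel) observations receive larger rewards. A minimal tensor-level sketch with illustrative outputs:

import torch

target = torch.tensor([[0.1, 0.4], [0.3, 0.2]])       # fixed random network
prediction = torch.tensor([[0.1, 0.4], [0.9, 0.8]])   # trained predictor
with torch.no_grad():
    rewards = torch.sum((prediction - target) ** 2, dim=1)
# rewards == tensor([0.0000, 0.7200]); the second observation is the more "novel" one.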
Example #26
    def get_trajectory_and_baseline_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        next_groupmate_obs: List[List[np.ndarray]],
        done: bool,
        agent_id: str = "",
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, float],
               Optional[AgentBufferField], Optional[AgentBufferField], ]:
        """
        Get value estimates, baseline estimates, and memories for a trajectory, in batch form.
        :param batch: An AgentBuffer that consists of a trajectory.
        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
            if this is not a terminal trajectory.
        :param next_groupmate_obs: the next observations from other members of the group.
        :param done: Set true if this is a terminal trajectory.
        :param agent_id: Agent ID of the agent that this trajectory belongs to.
        :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
            the baseline estimates as a Dict, the final value estimate as a Dict of [name, float], and
            optionally (if using memories) an AgentBufferField of initial critic and baseline memories to be used
            during update.
        """

        n_obs = len(self.policy.behavior_spec.observation_specs)

        current_obs = ObsUtil.from_buffer(batch, n_obs)
        groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)

        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
        groupmate_obs = [[
            ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs
        ] for _groupmate_obs in groupmate_obs]

        groupmate_actions = AgentAction.group_from_buffer(batch)

        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
        next_obs = [obs.unsqueeze(0) for obs in next_obs]

        next_groupmate_obs = [
            ModelUtils.list_to_tensor_list(_list_obs)
            for _list_obs in next_groupmate_obs
        ]
        # Expand dimensions of next critic obs
        next_groupmate_obs = [[_obs.unsqueeze(0) for _obs in _list_obs]
                              for _list_obs in next_groupmate_obs]

        if agent_id in self.value_memory_dict:
            # The agent_id should always be in both since they are added together
            _init_value_mem = self.value_memory_dict[agent_id]
            _init_baseline_mem = self.baseline_memory_dict[agent_id]
        else:
            _init_value_mem = (torch.zeros((1, 1, self.critic.memory_size))
                               if self.policy.use_recurrent else None)
            _init_baseline_mem = (torch.zeros((1, 1, self.critic.memory_size))
                                  if self.policy.use_recurrent else None)

        all_obs = ([current_obs] + groupmate_obs
                   if groupmate_obs is not None else [current_obs])
        all_next_value_mem: Optional[AgentBufferField] = None
        all_next_baseline_mem: Optional[AgentBufferField] = None
        with torch.no_grad():
            if self.policy.use_recurrent:
                (
                    value_estimates,
                    baseline_estimates,
                    all_next_value_mem,
                    all_next_baseline_mem,
                    next_value_mem,
                    next_baseline_mem,
                ) = self._evaluate_by_sequence_team(
                    current_obs,
                    groupmate_obs,
                    groupmate_actions,
                    _init_value_mem,
                    _init_baseline_mem,
                )
            else:
                value_estimates, next_value_mem = self.critic.critic_pass(
                    all_obs,
                    _init_value_mem,
                    sequence_length=batch.num_experiences)
                groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
                baseline_estimates, next_baseline_mem = self.critic.baseline(
                    current_obs,
                    groupmate_obs_and_actions,
                    _init_baseline_mem,
                    sequence_length=batch.num_experiences,
                )
        # Store the memory for the next trajectory
        self.value_memory_dict[agent_id] = next_value_mem
        self.baseline_memory_dict[agent_id] = next_baseline_mem

        all_next_obs = ([next_obs] + next_groupmate_obs
                        if next_groupmate_obs is not None else [next_obs])

        next_value_estimates, _ = self.critic.critic_pass(all_next_obs,
                                                          next_value_mem,
                                                          sequence_length=1)

        for name, estimate in baseline_estimates.items():
            baseline_estimates[name] = ModelUtils.to_numpy(estimate)

        for name, estimate in value_estimates.items():
            value_estimates[name] = ModelUtils.to_numpy(estimate)

        # The baseline and V should not be on the same done flag.
        for name, estimate in next_value_estimates.items():
            next_value_estimates[name] = ModelUtils.to_numpy(estimate)

        if done:
            for k in next_value_estimates:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimates[k][-1] = 0.0

        return (
            value_estimates,
            baseline_estimates,
            next_value_estimates,
            all_next_value_mem,
            all_next_baseline_mem,
        )
Example #27
    def get_trajectory_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        done: bool,
        agent_id: str = "",
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
        """
        Get value estimates and memories for a trajectory, in batch form.
        :param batch: An AgentBuffer that consists of a trajectory.
        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
            if this is not a terminal trajectory.
        :param done: Set true if this is a terminal trajectory.
        :param agent_id: Agent ID of the agent that this trajectory belongs to.
        :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
            the final value estimate as a Dict of [name, float], and optionally (if using memories)
            an AgentBufferField of initial critic memories to be used during update.
        """
        n_obs = len(self.policy.behavior_spec.observation_specs)

        if agent_id in self.critic_memory_dict:
            memory = self.critic_memory_dict[agent_id]
        else:
            memory = (
                torch.zeros((1, 1, self.critic.memory_size))
                if self.policy.use_recurrent
                else None
            )

        # Convert to tensors
        current_obs = [
            ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
        ]
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        next_obs = [obs.unsqueeze(0) for obs in next_obs]

        # If we're using LSTM, we want to get all the intermediate memories.
        all_next_memories: Optional[AgentBufferField] = None

        # To prevent memory leak and improve performance, evaluate with no_grad.
        with torch.no_grad():
            if self.policy.use_recurrent:
                (
                    value_estimates,
                    all_next_memories,
                    next_memory,
                ) = self._evaluate_by_sequence(current_obs, memory)
            else:
                value_estimates, next_memory = self.critic.critic_pass(
                    current_obs, memory, sequence_length=batch.num_experiences
                )

        # Store the memory for the next trajectory. This should NOT have a gradient.
        self.critic_memory_dict[agent_id] = next_memory

        next_value_estimate, _ = self.critic.critic_pass(
            next_obs, next_memory, sequence_length=1
        )

        for name, estimate in value_estimates.items():
            value_estimates[name] = ModelUtils.to_numpy(estimate)
            next_value_estimate[name] = ModelUtils.to_numpy(next_value_estimate[name])

        if done:
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0
            if agent_id in self.critic_memory_dict:
                self.critic_memory_dict.pop(agent_id)
        return value_estimates, next_value_estimate, all_next_memories
Example #28
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)

        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
                0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # LSTM shouldn't have a sequence length < 1, but guard against indexing out of range if it does.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                batch["memory"][i]
                [self.policy.m_size //
                 2:])  # only pass value part of memory to target network
            for i in range(offset, len(batch["memory"]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (torch.zeros_like(next_memories)
                      if next_memories is not None else None)

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
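        # Sample actions from the current policy; their log-probs feed the policy,
        # value, and entropy losses below.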
        (
            sampled_actions,
            log_probs,
            _,
            value_estimates,
            _,
        ) = self.policy.actor_critic.get_action_stats_and_value(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
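        # Evaluate the Q networks twice: on policy-sampled actions (for the policy
        # and value losses, with gradients through Q2 disabled) and on the buffer
        # actions (for the Q losses).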
        q1p_out, q2p_out = self.value_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        q1_out, q2_out = self.value_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

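        # Targets for the next observations come from the (soft-updated) target
        # network and carry no gradient.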
        with torch.no_grad():
            target_values, _ = self.target_network(
                next_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss + value_loss

        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self.policy.actor_critic.critic,
                               self.target_network, self.tau)
        update_stats = {
            "Losses/Policy Loss": policy_loss.item(),
            "Losses/Value Loss": value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Discrete Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.discrete)
            ).item(),
            "Policy/Continuous Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.continuous)
            ).item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats
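The target critic above is moved toward the online critic by Polyak averaging via ModelUtils.soft_update with the same tau passed in the call. A minimal standalone sketch of such a soft update (the function body here is an assumption for illustration, not the ML-Agents implementation):

import torch


def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(), target.parameters()):
            tgt_param.data.mul_(1.0 - tau)
            tgt_param.data.add_(tau * src_param.data)


# Example: nudge a target network toward its online counterpart.
online, target = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
soft_update(online, target, tau=0.005)  # a small tau gives a slowly-moving target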
Exemple #29
0
    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        min_policy_qs = {}
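        # Clipped double-Q: take the minimum of the two Q estimates under the current policy.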
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous.exp()
            _disc_ent_coef = self._log_ent_coef.discrete.exp()
            for name in values.keys():
                if self._action_spec.discrete_size <= 0:
                    min_policy_qs[name] = torch.min(q1p_out[name],
                                                    q2p_out[name])
                else:
                    disc_action_probs = log_probs.all_discrete_tensor.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _q1p_mean = torch.mean(
                        torch.stack([
                            torch.sum(_br, dim=1, keepdim=True)
                            for _br in _branched_q1p
                        ]),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack([
                            torch.sum(_br, dim=1, keepdim=True)
                            for _br in _branched_q2p
                        ]),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if self._action_spec.discrete_size <= 0:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1)
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup),
                    loss_masks)
                value_losses.append(value_loss)
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack([
                torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                for i, _lp in enumerate(branched_per_action_ent)
            ])
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0)
                    # Add continuous entropy bonus to minimum Q
                    if self._action_spec.continuous_size > 0:
                        v_backup += torch.sum(
                            _cont_ent_coef * log_probs.continuous_tensor,
                            dim=1,
                            keepdim=True,
                        )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name],
                                                 v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf or NaN found in SAC value loss")
        return value_loss
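For the continuous-only branch above, the regression target is the standard SAC value backup min(Q1, Q2) - alpha * sum(log pi). A minimal numeric sketch with hypothetical standalone tensors (q1, q2, log_pi, and alpha are illustrative assumptions, not values from the code above):

import torch

q1 = torch.randn(8)          # Q1(s, a ~ pi) for a batch of 8 states
q2 = torch.randn(8)          # Q2(s, a ~ pi)
log_pi = torch.randn(8, 3)   # log-probs of a 3-dim continuous action
alpha = 0.2                  # entropy coefficient, i.e. exp(log_ent_coef)

with torch.no_grad():
    v_backup = torch.min(q1, q2) - torch.sum(alpha * log_pi, dim=1)

# The value head is then regressed onto v_backup with a masked MSE, as above.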
Exemple #30
0
    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            rewards = ModelUtils.to_numpy(
                self._network.compute_reward(mini_batch))
        rewards = np.minimum(rewards, 1.0 / self.strength)
        return rewards * self._has_updated_once
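The _network.compute_reward call above is not shown here; in the common random-network-distillation (RND) formulation, the intrinsic reward is the prediction error between a frozen random target encoder and a trained predictor. A minimal sketch under that assumption (RNDSketch and its layers are hypothetical, not the ML-Agents network):

import torch


class RNDSketch(torch.nn.Module):
    def __init__(self, obs_size: int, embed_size: int = 64):
        super().__init__()
        self._target = torch.nn.Linear(obs_size, embed_size)      # frozen, random
        self._predictor = torch.nn.Linear(obs_size, embed_size)   # trained elsewhere
        for p in self._target.parameters():
            p.requires_grad_(False)

    def compute_reward(self, obs: torch.Tensor) -> torch.Tensor:
        # Per-sample squared prediction error serves as the intrinsic reward.
        return (self._target(obs) - self._predictor(obs)).pow(2).mean(dim=1)


# Example: intrinsic rewards for a batch of 8 observations of size 10.
rewards = RNDSketch(obs_size=10).compute_reward(torch.randn(8, 10))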