Example no. 1
    def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the action Tensor. In the continuous case, this corresponds to the
        action; in the discrete case, to the concatenation of one-hot action Tensors.
        """
        return self._action_flattener.forward(
            AgentAction.from_dict(mini_batch))
    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to predict
        the next state embedding.
        """
        actions = AgentAction.from_dict(mini_batch)
        flattened_action = self._action_flattener.forward(actions)
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), flattened_action), dim=1)

        return self.forward_model_next_state_prediction(forward_model_input)
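
For reference, the flattening described in the docstrings above can be pictured as follows. This is a minimal, self-contained sketch with made-up branch sizes, not the _action_flattener implementation:

import torch

# Hypothetical discrete action with two branches of sizes 3 and 2.
discrete_actions = torch.tensor([[2, 0], [1, 1]])    # shape: (batch, num_branches)
branch_sizes = [3, 2]

# One-hot encode each branch and concatenate along the feature dimension;
# this is the "flattened" action that the forward model consumes.
one_hots = [
    torch.nn.functional.one_hot(discrete_actions[:, i], size).float()
    for i, size in enumerate(branch_sizes)
]
flattened_action = torch.cat(one_hots, dim=1)        # shape: (batch, 3 + 2)

# The forward-model input is then the state embedding concatenated with it,
# exactly as in predict_next_state above:
# forward_model_input = torch.cat((state_embedding, flattened_action), dim=1)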
Example no. 3
    def _update_batch(self, mini_batch_demo: Dict[str, np.ndarray],
                      n_sequences: int) -> Dict[str, float]:
        """
        Helper function for update_batch.
        """
        vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
        act_masks = None
        expert_actions = AgentAction.from_dict(mini_batch_demo)
        if self.policy.behavior_spec.action_spec.discrete_size > 0:
            act_masks = ModelUtils.list_to_tensor(
                np.ones(
                    (
                        self.n_sequences * self.policy.sequence_length,
                        sum(self.policy.behavior_spec.action_spec.
                            discrete_branches),
                    ),
                    dtype=np.float32,
                ))

        memories = []
        if self.policy.use_recurrent:
            memories = torch.zeros(1, self.n_sequences, self.policy.m_size)

        if self.policy.use_vis_obs:
            vis_obs = []
            for idx, _ in enumerate(
                    self.policy.actor_critic.network_body.visual_processors):
                vis_ob = ModelUtils.list_to_tensor(
                    mini_batch_demo["visual_obs%d" % idx])
                vis_obs.append(vis_ob)
        else:
            vis_obs = []

        selected_actions, log_probs, _, _ = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        bc_loss = self._behavioral_cloning_loss(selected_actions, log_probs,
                                                expert_actions)
        self.optimizer.zero_grad()
        bc_loss.backward()

        self.optimizer.step()
        run_out = {"loss": bc_loss.item()}
        return run_out
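
The _behavioral_cloning_loss helper is not shown in these examples. A rough sketch of a loss with this shape, assuming continuous actions are regressed with mean-squared error and discrete actions are scored by their log-probabilities (the tensor names and single-branch layout here are assumptions, not taken from the source):

import torch

def bc_loss_sketch(policy_continuous, expert_continuous,
                   policy_discrete_log_probs, expert_discrete):
    # Continuous part: mean-squared error between policy and expert actions,
    # both of shape (batch, continuous_size).
    continuous_loss = torch.mean((policy_continuous - expert_continuous) ** 2)
    # Discrete part (single branch): negative log-likelihood of the expert's
    # choices, with log-probs of shape (batch, num_actions) and expert indices
    # of shape (batch, 1).
    discrete_loss = -policy_discrete_log_probs.gather(1, expert_discrete).mean()
    return continuous_loss + discrete_loss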
Example no. 4
    def _update_batch(
        self, mini_batch_demo: AgentBuffer, n_sequences: int
    ) -> Dict[str, float]:
        """
        Helper function for update_batch.
        """
        np_obs = ObsUtil.from_buffer(
            mini_batch_demo, len(self.policy.behavior_spec.sensor_specs)
        )
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
        act_masks = None
        expert_actions = AgentAction.from_dict(mini_batch_demo)
        if self.policy.behavior_spec.action_spec.discrete_size > 0:
            act_masks = ModelUtils.list_to_tensor(
                np.ones(
                    (
                        self.n_sequences * self.policy.sequence_length,
                        sum(self.policy.behavior_spec.action_spec.discrete_branches),
                    ),
                    dtype=np.float32,
                )
            )

        memories = []
        if self.policy.use_recurrent:
            memories = torch.zeros(1, self.n_sequences, self.policy.m_size)

        selected_actions, log_probs, _, _ = self.policy.sample_actions(
            tensor_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        bc_loss = self._behavioral_cloning_loss(
            selected_actions, log_probs, expert_actions
        )
        self.optimizer.zero_grad()
        bc_loss.backward()

        self.optimizer.step()
        run_out = {"loss": bc_loss.item()}
        return run_out
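
Examples no. 3 and no. 4 show the same helper before and after the observation refactor: the older version feeds vector and visual observations separately (vec_obs / vis_obs), while the newer one builds a single list with one tensor per sensor via ObsUtil. The all-ones act_masks simply enables every discrete action while cloning; a shape-only sketch with hypothetical sizes:

import numpy as np
import torch

batch_size = 4          # plays the role of n_sequences * sequence_length above
branch_sizes = [3, 2]   # hypothetical discrete branches

# One row per sample, one column per discrete action across all branches;
# a 1 means "this action is allowed".
act_masks = torch.as_tensor(
    np.ones((batch_size, sum(branch_sizes)), dtype=np.float32))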
    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on the
        action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        actions = AgentAction.from_dict(mini_batch)
        _inverse_loss = 0
        if self._action_spec.continuous_size > 0:
            sq_difference = (
                actions.continuous_tensor - predicted_action.continuous
            ) ** 2
            sq_difference = torch.sum(sq_difference, dim=1)
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                    2,
                )[1]
            )
        if self._action_spec.discrete_size > 0:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    actions.discrete_tensor, self._action_spec.discrete_branches
                ),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action.discrete + self.EPSILON) * true_action,
                dim=1,
            )
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(
                        mini_batch["masks"], dtype=torch.float
                    ),  # use masks not action_masks
                    2,
                )[1]
            )
        return _inverse_loss
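
The dynamic_partition(..., 2)[1] pattern above splits a tensor into the entries whose mask is 0 and those whose mask is 1 and keeps the latter, so each term is just a masked mean of a per-sample loss. A rough plain-PyTorch equivalent (not the ModelUtils implementation):

import torch

def masked_mean_sketch(per_sample_loss: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Keep only the samples whose mask is 1 and average them; this mirrors
    # torch.mean(ModelUtils.dynamic_partition(per_sample_loss, masks, 2)[1]).
    return per_sample_loss[masks.bool()].mean()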
def test_evaluate_actions(rnn, visual, discrete):
    policy = create_policy_mock(TrainerSettings(),
                                use_rnn=rnn,
                                use_discrete=discrete,
                                use_visual=visual)
    buffer = mb.simulate_rollout(64,
                                 policy.behavior_spec,
                                 memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    agent_action = AgentAction.from_dict(buffer)
    vis_obs = []
    for idx, _ in enumerate(
            policy.actor_critic.network_body.visual_processors):
        vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        vis_obs.append(vis_ob)

    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][i])
        for i in range(0, len(buffer["memory"]), policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=agent_action,
        memories=memories,
        seq_len=policy.sequence_length,
    )
    if discrete:
        _size = policy.behavior_spec.action_spec.discrete_size
    else:
        _size = policy.behavior_spec.action_spec.continuous_size

    assert log_probs.flatten().shape == (64, _size)
    assert entropy.shape == (64, )
    for val in values.values():
        assert val.shape == (64, )
Example no. 7
def test_evaluate_actions(rnn, visual, discrete):
    policy = create_policy_mock(TrainerSettings(),
                                use_rnn=rnn,
                                use_discrete=discrete,
                                use_visual=visual)
    buffer = mb.simulate_rollout(64,
                                 policy.behavior_spec,
                                 memory_size=policy.m_size)
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    agent_action = AgentAction.from_dict(buffer)
    np_obs = ObsUtil.from_buffer(buffer,
                                 len(policy.behavior_spec.observation_specs))
    tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][i])
        for i in range(0, len(buffer["memory"]), policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        tensor_obs,
        masks=act_masks,
        actions=agent_action,
        memories=memories,
        seq_len=policy.sequence_length,
    )
    if discrete:
        _size = policy.behavior_spec.action_spec.discrete_size
    else:
        _size = policy.behavior_spec.action_spec.continuous_size

    assert log_probs.flatten().shape == (64, _size)
    assert entropy.shape == (64, )
    for val in values.values():
        assert val.shape == (64, )
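
In both versions of this test the recurrent memories are sub-sampled at sequence boundaries and stacked into the (1, num_sequences, memory_size) layout the policy expects. A shape-only sketch with made-up sizes:

import torch

sequence_length, num_sequences, m_size = 16, 4, 32   # hypothetical sizes
flat_memories = torch.zeros(sequence_length * num_sequences, m_size)

# Take the memory saved at the start of each sequence...
per_sequence = [
    flat_memories[i] for i in range(0, len(flat_memories), sequence_length)
]
# ...and stack into (1, num_sequences, m_size) for the recurrent policy.
memories = torch.stack(per_sequence).unsqueeze(0)
assert memories.shape == (1, num_sequences, m_size)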
Example no. 8
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param batch: Experience mini-batch.
        :param update_target: Whether or not to update target value network
        :param reward_signal_batches: Minibatches to use for updating the reward signals,
            indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)

        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
                0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                # only pass value part of memory to target network
                batch["memory"][i][self.policy.m_size // 2:])
            for i in range(offset, len(batch["memory"]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (torch.zeros_like(next_memories)
                      if next_memories is not None else None)

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        (
            sampled_actions,
            log_probs,
            _,
            value_estimates,
            _,
        ) = self.policy.actor_critic.get_action_stats_and_value(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        q1p_out, q2p_out = self.value_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        q1_out, q2_out = self.value_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        with torch.no_grad():
            target_values, _ = self.target_network(
                next_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss + value_loss

        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self.policy.actor_critic.critic,
                               self.target_network, self.tau)
        update_stats = {
            "Losses/Policy Loss": policy_loss.item(),
            "Losses/Value Loss": value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Discrete Entropy Coeff":
                torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
            "Policy/Continuous Entropy Coeff":
                torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats
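
The ModelUtils.soft_update call at the end is the usual Polyak averaging of target-network parameters toward the online critic. A minimal stand-alone sketch of that rule (not the ModelUtils implementation):

import torch

def soft_update_sketch(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Move each target parameter a small step tau toward the source parameter:
    #     target <- tau * source + (1 - tau) * target
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(), target.parameters()):
            tgt_param.data.mul_(1.0 - tau).add_(tau * src_param.data)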
Example no. 9
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        # Get decayed parameters
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(
            self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[f"{name}_value_estimates"])
            returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"])

        vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)

        memories = [
            ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
                0, len(batch["memory"]), self.policy.sequence_length)
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        if self.policy.use_vis_obs:
            vis_obs = []
            for idx, _ in enumerate(
                    self.policy.actor_critic.network_body.visual_processors):
                vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                vis_obs.append(vis_ob)
        else:
            vis_obs = []

        log_probs, entropy, values = self.policy.evaluate_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        old_log_probs = ActionLogProbs.from_dict(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch["masks"],
                                               dtype=torch.bool)
        value_loss = self.ppo_value_loss(values, old_values, returns,
                                         decay_eps, loss_masks)
        policy_loss = self.ppo_policy_loss(
            ModelUtils.list_to_tensor(batch["advantages"]),
            log_probs,
            old_log_probs,
            loss_masks,
        )
        loss = (policy_loss + 0.5 * value_loss -
                decay_bet * ModelUtils.masked_mean(entropy, loss_masks))

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats
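
The ppo_policy_loss helper is not shown in this example; the clipped-surrogate objective it is expected to implement, sketched with plain tensors and a boolean loss mask (names and the per-sample log-prob layout are assumptions, not the optimizer's exact code):

import torch

def ppo_policy_loss_sketch(advantages, log_probs, old_log_probs, loss_masks, epsilon):
    # Probability ratio between the current and the old policy for each sample,
    # assuming log_probs and old_log_probs are already summed per sample.
    ratio = torch.exp(log_probs - old_log_probs)
    # Clipped surrogate: take the more pessimistic of the clipped and unclipped
    # estimates, average over the masked-in samples, and negate to get a loss.
    surrogate = torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages,
    )
    return -surrogate[loss_masks].mean()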