Example #1
0
 def __init__(
     self,
     observation_shapes: List[Tuple[int, ...]],
     network_settings: NetworkSettings,
     act_type: ActionType,
     act_size: List[int],
     conditional_sigma: bool = False,
     tanh_squash: bool = False,
 ):
     """
     Build an actor: a NetworkBody encoder followed by one action distribution.

     :param observation_shapes: Shapes of each observation, passed to NetworkBody.
     :param network_settings: Network configuration; ``hidden_units`` and, when
         ``memory`` is set, ``memory.memory_size`` determine the encoding width.
     :param act_type: ActionType.CONTINUOUS selects a GaussianDistribution; any
         other value selects a MultiCategoricalDistribution.
     :param act_size: For continuous actions ``act_size[0]`` is the action
         dimension; otherwise it is the list of discrete branch sizes.
     :param conditional_sigma: Forwarded to GaussianDistribution.
     :param tanh_squash: Forwarded to GaussianDistribution.
     """
     super().__init__()
     self.act_type = act_type
     self.act_size = act_size
     # Constant model metadata registered as nn.Parameters so it is part of the
     # module's state_dict. None of these values is a weight, so they are created
     # with requires_grad=False to keep them out of autograd and to let code that
     # filters parameters by requires_grad treat them as non-trainable.
     self.version_number = torch.nn.Parameter(
         torch.Tensor([2.0]), requires_grad=False
     )
     # NOTE(review): always 0, even when network_settings.memory is set below;
     # confirm whether the real memory size should be recorded here.
     self.memory_size = torch.nn.Parameter(torch.Tensor([0]), requires_grad=False)
     self.is_continuous_int = torch.nn.Parameter(
         torch.Tensor([int(act_type == ActionType.CONTINUOUS)]),
         requires_grad=False,
     )
     self.act_size_vector = torch.nn.Parameter(
         torch.Tensor(act_size), requires_grad=False
     )
     self.network_body = NetworkBody(observation_shapes, network_settings)
     if network_settings.memory is not None:
         # NOTE(review): presumably half of the memory vector is the recurrent
         # output used as the encoding — confirm against NetworkBody.
         self.encoding_size = network_settings.memory.memory_size // 2
     else:
         self.encoding_size = network_settings.hidden_units
     if self.act_type == ActionType.CONTINUOUS:
         # Continuous actions: a single Gaussian over act_size[0] dimensions.
         self.distribution = GaussianDistribution(
             self.encoding_size,
             act_size[0],
             conditional_sigma=conditional_sigma,
             tanh_squash=tanh_squash,
         )
     else:
         # Discrete actions: one categorical per branch size in act_size.
         self.distribution = MultiCategoricalDistribution(
             self.encoding_size, act_size
         )
Example #2
0
def test_multi_categorical_distribution():
    """
    Fit a MultiCategoricalDistribution to a fixed target and check action masking.

    Each branch of ``act_size`` is trained (MSE on log-probs) toward a target that
    gives 0.01 to every action except the first, which receives the remainder.
    The test then checks that the learned log-probs are within 0.1 of that target.
    It also checks that masking out all but the last action of each branch leaves
    that last action with log-probability ~0 (probability ~1).
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    # Tie the embedding width to hidden_size instead of repeating the literal.
    sample_embedding = torch.ones((1, hidden_size))
    multi_cat_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(multi_cat_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        """Return a (1, size) tensor of target log-probs favoring the first action."""
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    for _ in range(100):
        # An all-ones mask leaves every action available.
        dist_insts = multi_cat_dist(
            sample_embedding, masks=torch.ones((1, sum(act_size)))
        )
        loss = 0
        for i, dist_inst in enumerate(dist_insts):
            assert isinstance(dist_inst, CategoricalDistInstance)
            log_prob = dist_inst.all_log_prob()
            test_log_prob = create_test_prob(act_size[i])
            # Force log_probs to match the high probability for the first action generated by
            # create_test_prob
            loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # zip() below would silently skip missing branches, so pin the count first.
    assert len(dist_insts) == len(act_size)
    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        test_log_probs = create_test_prob(size)
        for _prob, _test_prob in zip(
            dist_inst.all_log_prob().flatten().tolist(),
            test_log_probs.flatten().tolist(),
        ):
            assert _prob == pytest.approx(_test_prob, abs=0.1)

    # Test masks: only the last action of each branch is left unmasked (1).
    mask_values = []
    for branch in act_size:
        mask_values += [0] * (branch - 1) + [1]
    masks = torch.tensor([mask_values])
    dist_insts = multi_cat_dist(sample_embedding, masks=masks)
    assert len(dist_insts) == len(act_size)
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        # With every other action masked, the last action should carry all the mass.
        assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)
Example #3
0
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
        deterministic: bool = False,
    ):
        """
        Build the action head of a policy from an ActionSpec.

        A Gaussian (continuous) head is created when the spec has at least one
        continuous action. A multi-categorical (discrete) head is created when it
        has at least one discrete action. A head that is not needed is left as None,
        so a spec may yield one head, both, or neither. Both heads take the encoding
        produced by the network body, and the forward method turns it into actions,
        log probs and entropies.
        :params hidden_size: Width of the encoding fed into the distributions.
        :params action_spec: The ActionSpec giving the continuous size and discrete branches.
        :params conditional_sigma: Passed to GaussianDistribution; whether the std depends on state.
        :params tanh_squash: Passed to GaussianDistribution; when True, export-time clipping is off.
        :params deterministic: Whether to select actions deterministically in policy.
        """
        super().__init__()
        self.encoding_size = hidden_size
        self.action_spec = action_spec

        # The continuous head is built before the discrete one, so submodule
        # registration (and parameter initialization) happens in that order.
        self._continuous_distribution = (
            GaussianDistribution(
                hidden_size,
                action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )
            if action_spec.continuous_size > 0
            else None
        )
        self._discrete_distribution = (
            MultiCategoricalDistribution(hidden_size, action_spec.discrete_branches)
            if action_spec.discrete_size > 0
            else None
        )

        # TorchPolicy clips actions during training; an exported (ONNX) model has
        # to clip by itself, unless tanh squashing already bounds the output.
        self._clip_action_on_export = not tanh_squash
        self._deterministic = deterministic
Example #4
0
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """
        Build an actor from an ActionSpec: a NetworkBody encoder followed by a
        single action distribution.

        :param observation_shapes: Shapes of each observation, passed to NetworkBody.
        :param network_settings: Network configuration; ``hidden_units`` and, when
            ``memory`` is set, ``memory.memory_size`` determine the encoding width.
        :param action_spec: Action space description. When ``is_continuous()`` is
            true a GaussianDistribution over ``continuous_size`` outputs is built;
            otherwise a MultiCategoricalDistribution over ``discrete_branches``.
        :param conditional_sigma: Forwarded to GaussianDistribution.
        :param tanh_squash: Forwarded to GaussianDistribution; when True,
            export-time action clipping is disabled.
        """
        super().__init__()
        self.action_spec = action_spec
        # Constant model metadata registered as nn.Parameters so it is part of the
        # module's state_dict. None of it is a weight, so every entry consistently
        # uses requires_grad=False.
        self.version_number = torch.nn.Parameter(
            torch.Tensor([2.0]), requires_grad=False
        )
        self.is_continuous_int = torch.nn.Parameter(
            torch.Tensor([int(self.action_spec.is_continuous())]),
            requires_grad=False,
        )
        # Total action width: continuous dimensions plus the sum of all discrete
        # branch sizes.
        self.act_size_vector = torch.nn.Parameter(
            torch.Tensor(
                [
                    self.action_spec.continuous_size
                    + sum(self.action_spec.discrete_branches)
                ]
            ),
            requires_grad=False,
        )
        self.network_body = NetworkBody(observation_shapes, network_settings)
        if network_settings.memory is not None:
            # NOTE(review): presumably half of the memory vector is the recurrent
            # output used as the encoding — confirm against NetworkBody.
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units

        if self.action_spec.is_continuous():
            self.distribution = GaussianDistribution(
                self.encoding_size,
                self.action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )
        else:
            # NOTE(review): a spec with both continuous and discrete parts gets
            # only one distribution here — confirm hybrid specs are not expected.
            self.distribution = MultiCategoricalDistribution(
                self.encoding_size, self.action_spec.discrete_branches
            )
        # During training, clipping is done in TorchPolicy, but we need to clip before ONNX
        # export as well.
        self._clip_action_on_export = not tanh_squash