Beispiel #1
0
                                         high=high[0])
                action_space = 2
            else:
                dist, action_size = ModelCatalog.get_action_dist(action,
                                                                 config=None)

            child_dist.append(dist)
            input_lens.append(action_size)

        return child_dist, input_lens

    def __init__(self, inputs, model):
        """Set up the distribution from the model's Tuple action space.

        Args:
            inputs: Forwarded unchanged to the parent constructor.
            model: Object exposing ``action_space``; also forwarded
                unchanged to the parent constructor.

        Raises:
            AssertionError: If ``model.action_space`` is not a
                ``gym.spaces.Tuple`` (only while assertions are enabled).
        """
        tuple_space = model.action_space
        assert isinstance(tuple_space, gym.spaces.Tuple)

        # The helper hands back the child distributions together with the
        # input length belonging to each of them.
        sub_dists, sub_lens = self._get_child_dists(tuple_space)

        parent_kwargs = {
            "action_space": tuple_space,
            "child_distributions": sub_dists,
            "input_lens": sub_lens,
        }
        super().__init__(inputs, model, **parent_kwargs)

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        """Return the summed input lengths of all child distributions.

        Args:
            action_space: Forwarded to ``_get_child_dists``.
            model_config: Not used by this implementation.

        Returns:
            The sum of the input lengths reported by ``_get_child_dists``
            for ``action_space``.
        """
        # Only the lengths are needed here; the child distributions that
        # come back alongside them are discarded.
        _, lens_per_child = Q1PhysActionDist._get_child_dists(action_space)
        return sum(lens_per_child)


# Register Q1PhysActionDist with ModelCatalog under the name
# "q1_phys_action_dist". This runs as a side effect of importing the module.
# NOTE(review): presumably lets a model config select this distribution by
# that name -- confirm against the training configuration.
ModelCatalog.register_custom_action_dist("q1_phys_action_dist",
                                         Q1PhysActionDist)