Esempio n. 1
0
def shape2space(space: ObservationSpace) -> Space:
    """Build the gym Space that corresponds to an ObservationSpace message.

    The ``shape`` oneof of ``space`` (read via ``WhichOneof("shape")``)
    selects the kind of space returned:

    * ``int64_range_list`` -> a ``Box`` of ``np.int64`` with one bound pair
      per range, using the full int64 limits as the fallback pair given to
      ``scalar_range2tuple``;
    * ``double_range_list`` -> a ``Box`` of ``np.float64`` with ``(-inf, inf)``
      as the fallback pair;
    * ``string_size_range`` / ``binary_size_range`` -> a ``Sequence`` of
      ``str`` / ``bytes`` with ``(0, None)`` as the fallback size range and
      ``space.opaque_data_format`` attached.

    Args:
        space: The ObservationSpace message to translate.

    Returns:
        The corresponding ``Box`` or ``Sequence`` instance.

    Raises:
        TypeError: If ``WhichOneof("shape")`` yields any other value.
    """

    def _box(ranges, dtype, fallback):
        # One (low, high) pair per dimension; lows and highs are then split
        # into two arrays of the requested dtype.
        limits = [scalar_range2tuple(rng, fallback) for rng in ranges]
        lows = np.array([pair[0] for pair in limits], dtype=dtype)
        highs = np.array([pair[1] for pair in limits], dtype=dtype)
        return Box(low=lows, high=highs, dtype=dtype)

    def _sequence(size_range, dtype, fallback):
        return Sequence(
            size_range=scalar_range2tuple(size_range, fallback),
            dtype=dtype,
            opaque_data_format=space.opaque_data_format,
        )

    # Each entry is a zero-argument builder so that only the field named by
    # the oneof is ever accessed.
    builders = {
        "int64_range_list": lambda: _box(
            space.int64_range_list.range,
            np.int64,
            (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
        ),
        "double_range_list": lambda: _box(
            space.double_range_list.range, np.float64, (-np.inf, np.inf)
        ),
        "string_size_range": lambda: _sequence(
            space.string_size_range, str, (0, None)
        ),
        "binary_size_range": lambda: _sequence(
            space.binary_size_range, bytes, (0, None)
        ),
    }

    build = builders.get(space.WhichOneof("shape"))
    if build is None:
        raise TypeError(f"Cannot determine shape of ObservationSpace: {space}")
    return build()
    def from_proto(cls, index: int, proto: ObservationSpace):
        """Build a spec instance from an ObservationSpace protocol buffer.

        Three values are derived from ``proto`` and handed to ``cls``:

        * ``space``: the gym space describing valid observations;
        * ``cb``: converts an Observation message into a python value;
        * ``to_string``: renders such a python value as a string for printing.

        The two recognised JSON ``opaque_data_format`` values are checked
        first and take precedence over the ``shape`` oneof; both are exposed
        as a ``str`` Sequence built from ``proto.string_size_range``.
        Otherwise the ``shape`` oneof picks a Box or Sequence space, and
        ``str`` is used as the printer.

        Args:
            index: Forwarded unchanged as the ``index`` constructor argument.
            proto: The ObservationSpace message to translate.

        Returns:
            A ``cls`` instance whose ``default_value`` is
            ``cb(proto.default_value)``.

        Raises:
            TypeError: If the format is not one of the JSON formats and the
                ``shape`` oneof is not one of the four recognised fields.
        """
        shape_kind = proto.WhichOneof("shape")

        def box_of(ranges, dtype, fallback):
            limits = [scalar_range2tuple(rng, fallback) for rng in ranges]
            return Box(
                low=np.array([pair[0] for pair in limits], dtype=dtype),
                high=np.array([pair[1] for pair in limits], dtype=dtype),
                dtype=dtype,
            )

        def sequence_of(size_range, dtype, fallback):
            return Sequence(
                size_range=scalar_range2tuple(size_range, fallback),
                dtype=dtype,
                opaque_data_format=proto.opaque_data_format,
            )

        # --- Observation -> python value converters -------------------------
        def read_graph(observation):
            payload = json.loads(observation.string_value)
            return nx.readwrite.json_graph.node_link_graph(
                payload, multigraph=True, directed=True
            )

        def read_json(observation):
            return json.loads(observation.string_value)

        def read_int64s(observation):
            return np.array(observation.int64_list.value, dtype=np.int64)

        def read_doubles(observation):
            return np.array(observation.double_list.value, dtype=np.float64)

        def read_string(observation):
            return observation.string_value

        def read_binary(observation):
            return observation.binary_value

        # --- python value -> printable string converters --------------------
        def show_graph(observation):
            node_link = nx.readwrite.json_graph.node_link_data(observation)
            return json.dumps(node_link, indent=2)

        def show_json(observation):
            return json.dumps(observation, indent=2)

        # JSON formats: (cb, to_string); the space is always a str Sequence.
        # TODO(cummins): Add a Graph space.
        json_formats = {
            "json://networkx/MultiDiGraph": (read_graph, show_graph),
            "json://": (read_json, show_json),
        }

        if proto.opaque_data_format in json_formats:
            space = sequence_of(proto.string_size_range, str, (0, None))
            cb, to_string = json_formats[proto.opaque_data_format]
        else:
            # Shape kinds: (lazy space builder, cb); printing always uses str.
            shape_handlers = {
                "int64_range_list": (
                    lambda: box_of(
                        proto.int64_range_list.range,
                        np.int64,
                        (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
                    ),
                    read_int64s,
                ),
                "double_range_list": (
                    lambda: box_of(
                        proto.double_range_list.range,
                        np.float64,
                        (-np.inf, np.inf),
                    ),
                    read_doubles,
                ),
                "string_size_range": (
                    lambda: sequence_of(proto.string_size_range, str, (0, None)),
                    read_string,
                ),
                "binary_size_range": (
                    lambda: sequence_of(proto.binary_size_range, bytes, (0, None)),
                    read_binary,
                ),
            }
            if shape_kind not in shape_handlers:
                raise TypeError(
                    f"Cannot determine shape of ObservationSpace: {proto}")
            build_space, cb = shape_handlers[shape_kind]
            space = build_space()
            to_string = str

        return cls(
            id=proto.name,
            index=index,
            space=space,
            cb=cb,
            to_string=to_string,
            deterministic=proto.deterministic,
            platform_dependent=proto.platform_dependent,
            default_value=cb(proto.default_value),
        )