def shape2space(space: ObservationSpace) -> Space:
    """Translate an ObservationSpace message into the matching gym Space.

    Numeric range lists (``int64_range_list`` / ``double_range_list``) become
    a :class:`Box` with one low/high pair per range, as produced by
    ``scalar_range2tuple``. String and binary size ranges become a
    :class:`Sequence` that carries the message's ``opaque_data_format``.

    :param space: The ObservationSpace message to convert.
    :return: A Box or Sequence space.
    :raises TypeError: If the ``shape`` oneof is unset or is not one of the
        four supported kinds.
    """

    def _box(ranges, dtype, defaults):
        # Collect per-dimension lower and upper bounds, then build the Box.
        lows, highs = [], []
        for scalar_range in ranges:
            bound = scalar_range2tuple(scalar_range, defaults)
            lows.append(bound[0])
            highs.append(bound[1])
        return Box(
            low=np.array(lows, dtype=dtype),
            high=np.array(highs, dtype=dtype),
            dtype=dtype,
        )

    def _sequence(size_range, dtype):
        # (0, None) is the fallback bound pair handed to scalar_range2tuple.
        return Sequence(
            size_range=scalar_range2tuple(size_range, (0, None)),
            dtype=dtype,
            opaque_data_format=space.opaque_data_format,
        )

    # Builders are wrapped in lambdas so that only the field belonging to the
    # active oneof member is ever read.
    builders = {
        "int64_range_list": lambda: _box(
            space.int64_range_list.range,
            np.int64,
            (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
        ),
        "double_range_list": lambda: _box(
            space.double_range_list.range, np.float64, (-np.inf, np.inf)
        ),
        "string_size_range": lambda: _sequence(space.string_size_range, str),
        "binary_size_range": lambda: _sequence(space.binary_size_range, bytes),
    }

    build = builders.get(space.WhichOneof("shape"))
    if build is None:
        raise TypeError(f"Cannot determine shape of ObservationSpace: {space}")
    return build()
def from_proto(cls, index: int, proto: ObservationSpace):
    """Construct a space from an ObservationSpace message.

    Three things are derived from ``proto``:

    * ``space`` -- the gym Space instance describing the observation;
    * ``cb`` -- a callback translating an Observation message into a python
      value;
    * ``to_string`` -- a callback rendering such a python value as a string.

    JSON-encoded observations (``opaque_data_format`` equal to ``"json://"``
    or ``"json://networkx/MultiDiGraph"``) are always described by the
    message's ``string_size_range``. They are decoded with :mod:`json`, and
    the graph format is additionally converted with networkx's node-link
    helpers. Every other observation's space is built by
    :func:`shape2space`, and its value is read from the Observation field
    that corresponds to the ``shape`` oneof.

    :param index: Passed through unchanged to ``cls``.
    :param proto: The ObservationSpace message.
    :return: An instance of ``cls``, whose ``default_value`` is
        ``proto.default_value`` translated through ``cb``.
    :raises TypeError: If the observation is not JSON-encoded and the
        ``shape`` oneof is unset or unsupported.
    """
    data_format = proto.opaque_data_format

    if data_format in ("json://networkx/MultiDiGraph", "json://"):
        # JSON payloads are described by the string size range regardless of
        # which ``shape`` oneof member is set.
        space = Sequence(
            size_range=scalar_range2tuple(proto.string_size_range, (0, None)),
            dtype=str,
            opaque_data_format=data_format,
        )

        if data_format == "json://networkx/MultiDiGraph":
            # TODO(cummins): Add a Graph space.

            def cb(observation):
                return nx.readwrite.json_graph.node_link_graph(
                    json.loads(observation.string_value),
                    multigraph=True,
                    directed=True,
                )

            def to_string(observation):
                return json.dumps(
                    nx.readwrite.json_graph.node_link_data(observation), indent=2
                )

        else:

            def cb(observation):
                return json.loads(observation.string_value)

            def to_string(observation):
                return json.dumps(observation, indent=2)

    else:
        # Space construction for plain (non-JSON) shapes is shared with
        # shape2space(), which raises TypeError for unset/unsupported shapes.
        space = shape2space(proto)
        shape_type = proto.WhichOneof("shape")

        # Read the value from the Observation field matching the shape kind.
        if shape_type == "int64_range_list":

            def cb(observation):
                return np.array(observation.int64_list.value, dtype=np.int64)

        elif shape_type == "double_range_list":

            def cb(observation):
                return np.array(observation.double_list.value, dtype=np.float64)

        elif shape_type == "string_size_range":

            def cb(observation):
                return observation.string_value

        elif shape_type == "binary_size_range":

            def cb(observation):
                return observation.binary_value

        else:
            # Defensive: keeps callbacks in lock-step with shape2space() should
            # it ever accept a shape that has no value translation here.
            raise TypeError(f"Cannot determine shape of ObservationSpace: {proto}")

        to_string = str

    return cls(
        id=proto.name,
        index=index,
        space=space,
        cb=cb,
        to_string=to_string,
        deterministic=proto.deterministic,
        platform_dependent=proto.platform_dependent,
        default_value=cb(proto.default_value),
    )