def _create_data(self, dataContainerPb):
    """Decode a ``pb.DataContainer`` message into native Python data.

    The container's ``type`` field selects which concrete message was packed
    into its ``data`` payload; that message is unpacked and converted:

    * ``pb.Discrete`` -> the ``data`` field of a ``pb.DiscreteDataContainer``.
    * ``pb.Box``      -> the repeated field of a ``pb.BoxDataContainer`` that
      matches its ``dtype`` (``intData`` / ``uintData`` / ``doubleData``;
      every other dtype reads ``floatData``). The repeated field itself is
      returned, flat and not reshaped.
    * ``pb.Tuple``    -> a ``tuple`` of the recursively decoded elements.
    * ``pb.Dict``     -> a ``dict`` mapping each element's ``name`` to its
      recursively decoded value.

    Returns:
        The decoded value, or ``None`` when ``type`` is none of the above.
    """
    kind = dataContainerPb.type

    if kind == pb.Discrete:
        discrete = pb.DiscreteDataContainer()
        dataContainerPb.data.Unpack(discrete)
        return discrete.data

    if kind == pb.Box:
        box = pb.BoxDataContainer()
        dataContainerPb.data.Unpack(box)
        # TODO: reshape the flat values using box.shape.
        if box.dtype == pb.INT:
            return box.intData
        if box.dtype == pb.UINT:
            return box.uintData
        if box.dtype == pb.DOUBLE:
            return box.doubleData
        # Any other dtype is treated as single-precision float.
        return box.floatData

    if kind == pb.Tuple:
        tuple_pb = pb.TupleDataContainer()
        dataContainerPb.data.Unpack(tuple_pb)
        return tuple(self._create_data(element) for element in tuple_pb.element)

    if kind == pb.Dict:
        dict_pb = pb.DictDataContainer()
        dataContainerPb.data.Unpack(dict_pb)
        return {
            element.name: self._create_data(element)
            for element in dict_pb.element
        }

    return None
def _pack_data(self, actions, spaceDesc):
    """Encode ``actions`` into a ``pb.DataContainer`` described by ``spaceDesc``.

    Args:
        actions: Value to encode, structured like ``spaceDesc``. It is a
            scalar for ``spaces.Discrete`` and a flat sequence for
            ``spaces.Box`` (only ``len(actions)`` is recorded as the shape).
            It is an iterable zipped with the sub-spaces for ``spaces.Tuple``
            and a mapping (``.items()``) for ``spaces.Dict``.
        spaceDesc: The space describing ``actions``. For ``Tuple``/``Dict``
            spaces, each sub-value is encoded recursively with the matching
            sub-space taken from ``spaceDesc.spaces``.

    Returns:
        A ``pb.DataContainer`` whose ``type`` names the space kind and whose
        ``data`` payload holds the matching container message. Dict elements
        carry their key in ``name``. For any other space type the container
        is returned without ``type`` or ``data`` being set.
    """
    dataContainer = pb.DataContainer()

    if isinstance(spaceDesc, spaces.Discrete):
        dataContainer.type = pb.Discrete
        discreteContainerPb = pb.DiscreteDataContainer()
        discreteContainerPb.data = actions
        dataContainer.data.Pack(discreteContainerPb)

    elif isinstance(spaceDesc, spaces.Box):
        dataContainer.type = pb.Box
        boxContainerPb = pb.BoxDataContainer()
        boxContainerPb.shape.extend([len(actions)])
        dtype = spaceDesc.dtype
        # NOTE(review): if dtype is a numpy dtype, float64 compares equal to
        # 'float' and is sent as FLOAT, so the DOUBLE branch is likely
        # unreachable for it. Confirm the intended wire type before changing.
        if dtype in ['int', 'int8', 'int16', 'int32', 'int64']:
            boxContainerPb.dtype = pb.INT
            boxContainerPb.intData.extend(actions)
        elif dtype in ['uint', 'uint8', 'uint16', 'uint32', 'uint64']:
            boxContainerPb.dtype = pb.UINT
            boxContainerPb.uintData.extend(actions)
        elif dtype in ['float', 'float32', 'float64']:
            boxContainerPb.dtype = pb.FLOAT
            boxContainerPb.floatData.extend(actions)
        elif dtype in ['double']:
            boxContainerPb.dtype = pb.DOUBLE
            boxContainerPb.doubleData.extend(actions)
        else:
            # Unrecognised dtypes fall back to single-precision float.
            boxContainerPb.dtype = pb.FLOAT
            boxContainerPb.floatData.extend(actions)
        dataContainer.data.Pack(boxContainerPb)

    elif isinstance(spaceDesc, spaces.Tuple):
        dataContainer.type = pb.Tuple
        tupleDataPb = pb.TupleDataContainer()
        # Pair each sub-action with THIS space's sub-spaces (not the top-level
        # action space) so that nested Tuple/Dict spaces encode correctly.
        subDataList = [
            self._pack_data(subAction, subSpace)
            for subAction, subSpace in zip(actions, list(spaceDesc.spaces))
        ]
        tupleDataPb.element.extend(subDataList)
        dataContainer.data.Pack(tupleDataPb)

    elif isinstance(spaceDesc, spaces.Dict):
        dataContainer.type = pb.Dict
        dictDataPb = pb.DictDataContainer()
        subDataList = []
        for sName, subAction in actions.items():
            # Look the key up in THIS space's sub-spaces (see Tuple branch).
            subData = self._pack_data(subAction, spaceDesc.spaces[sName])
            subData.name = sName  # lets _create_data rebuild the dict keys
            subDataList.append(subData)
        dictDataPb.element.extend(subDataList)
        dataContainer.data.Pack(dictDataPb)

    return dataContainer