def try_space_inference_from_list(list_op, dtype=None, **low_high):
    """
    Attempts to infer a Space from a list op. A list op may be the result of fetching state from a
    Python memory.

    Args:
        list_op (list): List with arbitrary sub-structure.
        dtype (Optional[str]): An optional hint for the dtype of the resulting BoxSpace.

    Keyword Args:
        low/high (Optional[int,float]): Optional `low`/`high` bounds for the resulting BoxSpace.

    Returns:
        Space: Inferred Space object represented by the list.
    """
    shape = len(list_op)
    if shape > 0:
        # Try to infer more by looking inside the list.
        elem = list_op[0]
        if (get_backend() == "pytorch" and isinstance(elem, torch.Tensor)) or \
                (get_backend() == "tf" and isinstance(elem, tf.Tensor)):
            list_type = dtype or elem.dtype
            inner_shape = elem.shape
            return BoxSpace.from_spec(
                spec=convert_dtype(list_type, "np"), shape=(shape,) + inner_shape,
                add_batch_rank=True, **low_high
            )
        elif isinstance(elem, list):
            inner_shape = len(elem)
            return BoxSpace.from_spec(
                spec=convert_dtype(dtype or float, "np"), shape=(shape, inner_shape),
                add_batch_rank=True, **low_high
            )
        # IntBox -> elem must be an int and the dtype hint must match (or be None).
        elif isinstance(elem, int) and (dtype is None or dtype == "int"):
            # Some values may be floats written without a decimal part, so check all other items:
            # if any float is found -> FloatBox, otherwise -> IntBox.
            has_floats = any(isinstance(el, float) for el in list_op)
            if has_floats is False:
                return IntBox.from_spec(shape=(shape,), add_batch_rank=True, **low_high)
            else:
                return FloatBox.from_spec(shape=(shape,), add_batch_rank=True, **low_high)
        # FloatBox -> elem must be a float (or int) and the dtype hint must match (or be None).
        elif isinstance(elem, (float, int)) and (dtype is None or dtype == "float"):
            return FloatBox.from_spec(shape=(shape,), add_batch_rank=True, **low_high)

    # Most general guess is a FloatBox.
    return FloatBox(shape=(shape,), **low_high)
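
# Usage sketch (hypothetical `_demo_*` helper, not part of the original module): exercises only the
# plain-Python branches of `try_space_inference_from_list` above; no backend tensors are involved.
def _demo_try_space_inference_from_list():
    ints = try_space_inference_from_list([0, 2, 1])                   # -> IntBox, shape (3,)
    mixed = try_space_inference_from_list([0, 1.5, 2])                # -> FloatBox (a float was found)
    nested = try_space_inference_from_list([[0.0, 1.0], [2.0, 3.0]])  # -> float BoxSpace, shape (2, 2)
    hinted = try_space_inference_from_list([0, 1, 2], dtype="float", low=0.0, high=5.0)  # -> bounded FloatBox
    return ints, mixed, nested, hinted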
def try_space_inference_from_list(list_op):
    """
    Attempts to infer a Space from a list op. A list op may be the result of fetching state from a
    Python memory.

    Args:
        list_op (list): List with arbitrary sub-structure.

    Returns:
        Space: Inferred Space object represented by the list.
    """
    if get_backend() == "pytorch":
        batch_shape = len(list_op)
        if batch_shape > 0:
            # Try to infer more by looking inside the list.
            elem = list_op[0]
            if isinstance(elem, torch.Tensor):
                list_type = elem.dtype
                inner_shape = elem.shape
                return BoxSpace.from_spec(
                    spec=convert_dtype(list_type, "np"), shape=(batch_shape,) + inner_shape,
                    add_batch_rank=True
                )
            elif isinstance(elem, list):
                inner_shape = len(elem)
                return BoxSpace.from_spec(
                    spec=convert_dtype(float, "np"), shape=(batch_shape, inner_shape),
                    add_batch_rank=True
                )
            else:
                # Most general guess is a FloatBox.
                return FloatBox(shape=(batch_shape,))
    else:
        raise ValueError("List inference should only be attempted on the Python backend.")
def get_preprocessed_space(self, space):
    """
    Returns a Space with the same structure and shapes as `space`, but with every primitive
    component translated into a FloatBox (batch- and time-rank flags are preserved).
    """
    # Translate to corresponding FloatBoxes.
    ret = dict()
    for key, value in space.flatten().items():
        ret[key] = FloatBox(shape=value.shape, add_batch_rank=value.has_batch_rank,
                            add_time_rank=value.has_time_rank)
    return unflatten_op(ret)
def as_one_hot_float_space(self):
    """
    Returns a new FloatBox Space resulting from one-hot flattening out this space along its
    number of categories.

    Returns:
        FloatBox: The resulting FloatBox Space (with the same batch and time-rank settings).
    """
    return FloatBox(
        low=0.0, high=1.0, shape=self.get_shape(with_category_rank=True),
        add_batch_rank=self.has_batch_rank, add_time_rank=self.has_time_rank
    )
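
# Usage sketch (hypothetical `_demo_*` helper, not part of the original module): assumes
# `as_one_hot_float_space` is the IntBox method shown above.
def _demo_as_one_hot_float_space():
    space = IntBox(4, shape=(2,), add_batch_rank=True)  # 2 ints, each with 4 categories
    one_hot = space.as_one_hot_float_space()
    # `one_hot` is a FloatBox with low=0.0/high=1.0, the category rank appended to the shape
    # (here: (2, 4)), and the same batch-/time-rank settings as `space`.
    return one_hot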
def get_space_from_op(op):
    """
    Tries to re-create a Space object given some DataOp (e.g. a tf op).
    This is useful for shape inference on returned ops after having run through a graph_fn.

    Args:
        op (DataOp): The op to create a corresponding Space for.

    Returns:
        Space: The inferred Space object.
    """
    # a Dict
    if isinstance(op, dict):  # DataOpDict
        spec = {}
        add_batch_rank = False
        add_time_rank = False
        for key, value in op.items():
            spec[key] = get_space_from_op(value)
            if spec[key].has_batch_rank:
                add_batch_rank = True
            if spec[key].has_time_rank:
                add_time_rank = True
        return Dict(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)
    # a Tuple
    elif isinstance(op, tuple):  # DataOpTuple
        spec = []
        add_batch_rank = False
        add_time_rank = False
        for i in op:
            space = get_space_from_op(i)
            if space == 0:
                return 0
            spec.append(space)
            if spec[-1].has_batch_rank:
                add_batch_rank = True
            if spec[-1].has_time_rank:
                add_time_rank = True
        return Tuple(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)
    # primitive Space -> infer from op dtype and shape
    else:
        # Op itself is a single value, simple python type.
        if isinstance(op, (bool, int, float)):
            return BoxSpace.from_spec(spec=type(op), shape=())
        elif isinstance(op, str):
            raise RLGraphError("Cannot derive Space from non-allowed op ({})!".format(op))
        # A single numpy array.
        elif isinstance(op, np.ndarray):
            return BoxSpace.from_spec(spec=convert_dtype(str(op.dtype), "np"), shape=op.shape)
        elif isinstance(op, list):
            return try_space_inference_from_list(op)
        # No Space: e.g. the tf.no_op, a distribution (anything that's not a tensor).
        # PyTorch Tensors do not have get_shape so must check backend.
        elif hasattr(op, "dtype") is False or (get_backend() == "tf" and not hasattr(op, "get_shape")):
            return 0
        # Some tensor: can be converted into a BoxSpace.
        else:
            shape = get_shape(op)
            # Unknown shape (e.g. a cond op).
            if shape is None:
                return 0
            add_batch_rank = False
            add_time_rank = False
            time_major = False
            new_shape = list(shape)

            # New way: Detect via op._batch_rank and op._time_rank properties where these ranks are.
            if hasattr(op, "_batch_rank") and isinstance(op._batch_rank, int):
                add_batch_rank = True
                new_shape[op._batch_rank] = -1
            # elif get_backend() == "pytorch":
            #     if isinstance(op, torch.Tensor):
            #         if op.dim() > 1 and shape[0] == 1:
            #             add_batch_rank = True
            #             new_shape[0] = 1
            if hasattr(op, "_time_rank") and isinstance(op._time_rank, int):
                add_time_rank = True
                if op._time_rank == 0:
                    time_major = True
                new_shape[op._time_rank] = -1
            shape = tuple(n for n in new_shape if n != -1)

            # Old way: Detect automatically whether the first rank(s) are batch and/or time rank.
            if add_batch_rank is False and add_time_rank is False and shape != () and shape[0] is None:
                if len(shape) > 1 and shape[1] is None:
                    # raise RLGraphError(
                    #     "ERROR: Cannot determine time-major flag if both batch- and time-ranks are in an op "
                    #     "w/o saying which rank goes to which position!"
                    # )
                    shape = shape[2:]
                    add_time_rank = True
                else:
                    shape = shape[1:]
                    add_batch_rank = True
            # TODO: If op._batch_rank and/or op._time_rank are not set, set them now.

            base_dtype = op.dtype.base_dtype if hasattr(op.dtype, "base_dtype") else op.dtype
            # PyTorch does not have a bool type.
            if get_backend() == "pytorch":
                if op.dtype is torch.uint8:
                    base_dtype = bool
            base_dtype_str = str(base_dtype)
            # FloatBox
            if "float" in base_dtype_str:
                return FloatBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                                time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
            # IntBox
            elif "int" in base_dtype_str:
                high = getattr(op, "_num_categories", None)
                return IntBox(high, shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                              time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
            # a BoolBox
            elif "bool" in base_dtype_str:
                return BoolBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                               time_major=time_major)
            # a TextBox
            elif "string" in base_dtype_str:
                return TextBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                               time_major=time_major)

    raise RLGraphError("ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))
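
# Usage sketch (hypothetical `_demo_*` helper, not part of the original module): only exercises the
# primitive (non-tensor) branches of `get_space_from_op`, so no backend tensors are required.
def _demo_get_space_from_op():
    import numpy as np
    s1 = get_space_from_op(1.5)                            # -> scalar float BoxSpace
    s2 = get_space_from_op(np.zeros((3, 2), np.float32))   # -> FloatBox with shape (3, 2)
    s3 = get_space_from_op((np.array([0, 1]), 2.0))        # -> Tuple of two inferred Boxes
    s4 = get_space_from_op(object())                       # -> 0 (no dtype -> not convertible to a Space)
    return s1, s2, s3, s4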
def get_space_from_op(op, read_key_hints=False, dtype=None, low=None, high=None):
    """
    Tries to re-create a Space object given some DataOp (e.g. a tf op).
    This is useful for shape inference on returned ops after having run through a graph_fn.

    Args:
        op (DataOp): The op to create a corresponding Space for.
        read_key_hints (bool): If True, tries to read type- and low/high-hints from the pattern of the
            Dict keys (str).
            - A leading "I_": IntBox, "F_": FloatBox, "B_": BoolBox.
            - A trailing "_low=0.0": Low value.
            - A trailing "_high=1.0": High value.
            E.g. Dict key "F_somekey_low=0.0_high=2.0" indicates a FloatBox with low=0.0 and high=2.0.
            Dict key "I_somekey" indicates an IntBox with no limits.
            Dict key "I_somekey_high=5" indicates an IntBox with high=5 (values 0-4).
            Default: False.
        dtype (Optional[str]): An optional hint for the `dtype` of a resulting BoxSpace.
        low (Optional[int,float]): An optional hint for the `low` property of a resulting BoxSpace.
        high (Optional[int,float]): An optional hint for the `high` property of a resulting BoxSpace.

    Returns:
        Space: The inferred Space object.
    """
    # a Dict
    if isinstance(op, dict):  # DataOpDict
        spec = {}
        add_batch_rank = False
        add_time_rank = False
        for key, value in op.items():
            # Try to infer hints from the key.
            if read_key_hints is True:
                dtype, low, high = get_space_hints_from_dict_key(key)
            spec[key] = get_space_from_op(value, dtype=dtype, low=low, high=high)
            # Value has no Space equivalent -> give up on the entire Dict.
            if spec[key] == 0:
                return 0
            if spec[key].has_batch_rank:
                add_batch_rank = True
            if spec[key].has_time_rank:
                add_time_rank = True
        return Dict(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)
    # a Tuple
    elif isinstance(op, tuple):  # DataOpTuple
        spec = []
        add_batch_rank = False
        add_time_rank = False
        for i in op:
            space = get_space_from_op(i)
            if space == 0:
                return 0
            spec.append(space)
            if spec[-1].has_batch_rank:
                add_batch_rank = True
            if spec[-1].has_time_rank:
                add_time_rank = True
        return Tuple(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)
    # primitive Space -> infer from op dtype and shape
    else:
        low_high = {}
        if high is not None:
            low_high["high"] = high
        if low is not None:
            low_high["low"] = low
        # Op itself is a single value, simple python type.
        if isinstance(op, (bool, int, float)):
            return BoxSpace.from_spec(spec=(dtype or type(op)), shape=(), **low_high)
        elif isinstance(op, str):
            raise RLGraphError("Cannot derive Space from non-allowed op ({})!".format(op))
        # A single numpy array.
        elif isinstance(op, np.ndarray):
            return BoxSpace.from_spec(spec=convert_dtype(str(op.dtype), "np"), shape=op.shape, **low_high)
        elif isinstance(op, list):
            return try_space_inference_from_list(op, dtype=dtype, **low_high)
        # No Space: e.g. the tf.no_op, a distribution (anything that's not a tensor).
        # PyTorch Tensors do not have get_shape so must check backend.
        elif hasattr(op, "dtype") is False or (get_backend() == "tf" and not hasattr(op, "get_shape")):
            return 0
        # Some tensor: can be converted into a BoxSpace.
        else:
            shape = get_shape(op)
            # Unknown shape (e.g. a cond op).
            if shape is None:
                return 0
            add_batch_rank = False
            add_time_rank = False
            time_major = False
            new_shape = list(shape)

            # New way: Detect via op._batch_rank and op._time_rank properties where these ranks are.
            if hasattr(op, "_batch_rank") and isinstance(op._batch_rank, int):
                add_batch_rank = True
                new_shape[op._batch_rank] = -1
            # elif get_backend() == "pytorch":
            #     if isinstance(op, torch.Tensor):
            #         if op.dim() > 1 and shape[0] == 1:
            #             add_batch_rank = True
            #             new_shape[0] = 1
            if hasattr(op, "_time_rank") and isinstance(op._time_rank, int):
                add_time_rank = True
                if op._time_rank == 0:
                    time_major = True
                new_shape[op._time_rank] = -1
            shape = tuple(n for n in new_shape if n != -1)

            # Old way: Detect automatically whether the first rank(s) are batch and/or time rank.
            if add_batch_rank is False and add_time_rank is False and shape != () and shape[0] is None:
                if len(shape) > 1 and shape[1] is None:
                    # raise RLGraphError(
                    #     "ERROR: Cannot determine time-major flag if both batch- and time-ranks are in an op "
                    #     "w/o saying which rank goes to which position!"
                    # )
                    shape = shape[2:]
                    add_time_rank = True
                else:
                    shape = shape[1:]
                    add_batch_rank = True
            # TODO: If op._batch_rank and/or op._time_rank are not set, set them now.

            base_dtype = op.dtype.base_dtype if hasattr(op.dtype, "base_dtype") else op.dtype
            # PyTorch does not have a bool type.
            if get_backend() == "pytorch":
                if op.dtype is torch.uint8:
                    base_dtype = bool
            base_dtype_str = str(base_dtype)
            # FloatBox
            if "float" in base_dtype_str:
                return FloatBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                                time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
            # IntBox
            elif "int" in base_dtype_str:
                high_ = high or getattr(op, "_num_categories", None)
                return IntBox(high_, shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                              time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
            # a BoolBox
            elif "bool" in base_dtype_str:
                return BoolBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                               time_major=time_major)
            # a TextBox
            elif "string" in base_dtype_str:
                return TextBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                               time_major=time_major)

    raise RLGraphError("ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))