def get_space_from_op(op):
    """
    Tries to re-create a Space object given some DataOp (e.g. a tf op).
    This is useful for shape inference on returned ops after having run through a graph_fn.

    Args:
        op (DataOp): The op to create a corresponding Space for.

    Returns:
        Union[Space,int]: The inferred Space object, or 0 if no Space can be inferred for `op`
            (e.g. a no-op, a non-tensor object, a tensor of unknown shape, or a dict/tuple
            containing any such component).

    Raises:
        RLGraphError: If `op` is a str, or a tensor whose dtype cannot be mapped to a BoxSpace.
    """
    # NOTE(review): This file defines get_space_from_op a second time further down (with
    # read_key_hints/dtype/low/high parameters); that later definition rebinds this name.

    # a Dict
    if isinstance(op, dict):  # DataOpDict
        spec = {}
        add_batch_rank = False
        add_time_rank = False
        for key, value in op.items():
            spec[key] = get_space_from_op(value)
            # A component without an inferable Space (0) makes the whole container
            # non-inferable - same as in the tuple branch. Without this check, the
            # attribute accesses below would fail on the int 0.
            if spec[key] == 0:
                return 0
            if spec[key].has_batch_rank:
                add_batch_rank = True
            if spec[key].has_time_rank:
                add_time_rank = True
        return Dict(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)

    # a Tuple
    elif isinstance(op, tuple):  # DataOpTuple
        spec = []
        add_batch_rank = False
        add_time_rank = False
        for i in op:
            space = get_space_from_op(i)
            if space == 0:
                return 0
            spec.append(space)
            if space.has_batch_rank:
                add_batch_rank = True
            if space.has_time_rank:
                add_time_rank = True
        return Tuple(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)

    # primitive Space -> infer from op dtype and shape
    # Op itself is a single value, simple python type.
    if isinstance(op, (bool, int, float)):
        return BoxSpace.from_spec(spec=type(op), shape=())
    elif isinstance(op, str):
        raise RLGraphError(
            "Cannot derive Space from non-allowed op ({})!".format(op))
    # A single numpy array.
    elif isinstance(op, np.ndarray):
        return BoxSpace.from_spec(spec=convert_dtype(str(op.dtype), "np"), shape=op.shape)
    elif isinstance(op, list):
        return try_space_inference_from_list(op)
    # No Space: e.g. the tf.no_op, a distribution (anything that's not a tensor).
    # PyTorch Tensors do not have get_shape so must check backend.
    elif not hasattr(op, "dtype") or (get_backend() == "tf" and not hasattr(op, "get_shape")):
        return 0

    # Some tensor: can be converted into a BoxSpace.
    shape = get_shape(op)
    # Unknown shape (e.g. a cond op).
    if shape is None:
        return 0

    add_batch_rank = False
    add_time_rank = False
    time_major = False
    new_shape = list(shape)

    # New way: Detect via op._batch_rank and op._time_rank properties where these ranks are.
    if hasattr(op, "_batch_rank") and isinstance(op._batch_rank, int):
        add_batch_rank = True
        new_shape[op._batch_rank] = -1
    if hasattr(op, "_time_rank") and isinstance(op._time_rank, int):
        add_time_rank = True
        if op._time_rank == 0:
            time_major = True
        new_shape[op._time_rank] = -1
    # Remove the ranks marked with -1 above; what remains is the Space's core shape.
    shape = tuple(n for n in new_shape if n != -1)

    # Old way: Detect automatically whether the first rank(s) are batch and/or time rank.
    if add_batch_rank is False and add_time_rank is False and shape != () and shape[0] is None:
        if len(shape) > 1 and shape[1] is None:
            # Two leading unknown ranks: strip both and re-add them as batch + time rank.
            # Which one is time cannot be determined here; time_major stays False.
            shape = shape[2:]
            add_time_rank = True
        else:
            shape = shape[1:]
        # In both cases the first stripped rank is re-added as the batch rank.
        add_batch_rank = True
    # TODO: If op._batch_rank and/or op._time_rank are not set, set them now.

    base_dtype = op.dtype.base_dtype if hasattr(op.dtype, "base_dtype") else op.dtype
    # Under the pytorch backend, torch.uint8 tensors are treated as booleans
    # (original assumption: "PyTorch does not have a bool type").
    if get_backend() == "pytorch":
        if op.dtype is torch.uint8:
            base_dtype = bool
    base_dtype_str = str(base_dtype)

    # FloatBox
    if "float" in base_dtype_str:
        return FloatBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                        time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
    # IntBox
    elif "int" in base_dtype_str:
        high = getattr(op, "_num_categories", None)
        return IntBox(high, shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                      time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
    # a BoolBox
    elif "bool" in base_dtype_str:
        return BoolBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                       time_major=time_major)
    # a TextBox
    elif "string" in base_dtype_str:
        return TextBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                       time_major=time_major)

    raise RLGraphError(
        "ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))
def get_space_from_op(op, read_key_hints=False, dtype=None, low=None, high=None):
    """
    Tries to re-create a Space object given some DataOp (e.g. a tf op).
    This is useful for shape inference on returned ops after having run through a graph_fn.

    Args:
        op (DataOp): The op to create a corresponding Space for.

        read_key_hints (bool): If True, tries to read type- and low/high-hints from the pattern of the Dict keys (str).
            - Preceding "I_": IntBox, "F_": FloatBox, "B_": BoolBox.
            - Succeeding "_low=0.0": Low value.
            - Succeeding "_high=1.0": High value.
            E.g. Dict key "F_somekey_low=0.0_high=2.0" indicates a FloatBox with low=0.0 and high=2.0.
                Dict key "I_somekey" indicates an intbox with no limits.
                Dict key "I_somekey_high=5" indicates an intbox with high=5 (values 0-4).

            Default: False.

        dtype (Optional[str]): An optional indicator, what the `dtype` of a BoxSpace should be.
        low (Optional[int,float]): An optional indicator, what the `low` property for a BoxSpace should be.
        high (Optional[int,float]): An optional indicator, what the `high` property for a BoxSpace should be.

    Returns:
        Union[Space,int]: The inferred Space object, or 0 if no Space can be inferred for `op`
            (e.g. a no-op, a non-tensor object, a tensor of unknown shape, or a dict/tuple
            containing any such component).

    Raises:
        RLGraphError: If `op` is a str, or a tensor whose dtype cannot be mapped to a BoxSpace.
    """
    # a Dict
    if isinstance(op, dict):  # DataOpDict
        spec = {}
        add_batch_rank = False
        add_time_rank = False
        for key, value in op.items():
            # Try to infer hints from the key; otherwise pass the caller's hints through.
            # Per-key locals are used so the function's own parameters are not overwritten.
            if read_key_hints is True:
                key_dtype, key_low, key_high = get_space_hints_from_dict_key(key)
            else:
                key_dtype, key_low, key_high = dtype, low, high
            spec[key] = get_space_from_op(value, dtype=key_dtype, low=key_low, high=key_high)
            # A component without an inferable Space makes the whole Dict non-inferable.
            if spec[key] == 0:
                return 0
            if spec[key].has_batch_rank:
                add_batch_rank = True
            if spec[key].has_time_rank:
                add_time_rank = True
        return Dict(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)

    # a Tuple
    elif isinstance(op, tuple):  # DataOpTuple
        spec = []
        add_batch_rank = False
        add_time_rank = False
        # NOTE(review): unlike the Dict branch, dtype/low/high hints are not forwarded to
        # tuple components - confirm this is intended.
        for i in op:
            space = get_space_from_op(i)
            if space == 0:
                return 0
            spec.append(space)
            if space.has_batch_rank:
                add_batch_rank = True
            if space.has_time_rank:
                add_time_rank = True
        return Tuple(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)

    # primitive Space -> infer from op dtype and shape
    # Only forward bounds that were actually given, so the Box constructors keep their defaults otherwise.
    low_high = {}
    if high is not None:
        low_high["high"] = high
    if low is not None:
        low_high["low"] = low

    # Op itself is a single value, simple python type.
    if isinstance(op, (bool, int, float)):
        return BoxSpace.from_spec(spec=(dtype or type(op)), shape=(), **low_high)
    elif isinstance(op, str):
        raise RLGraphError(
            "Cannot derive Space from non-allowed op ({})!".format(op))
    # A single numpy array.
    elif isinstance(op, np.ndarray):
        return BoxSpace.from_spec(spec=convert_dtype(str(op.dtype), "np"), shape=op.shape, **low_high)
    elif isinstance(op, list):
        return try_space_inference_from_list(op, dtype=dtype, **low_high)
    # No Space: e.g. the tf.no_op, a distribution (anything that's not a tensor).
    # PyTorch Tensors do not have get_shape so must check backend.
    elif not hasattr(op, "dtype") or (get_backend() == "tf" and not hasattr(op, "get_shape")):
        return 0

    # Some tensor: can be converted into a BoxSpace.
    shape = get_shape(op)
    # Unknown shape (e.g. a cond op).
    if shape is None:
        return 0

    add_batch_rank = False
    add_time_rank = False
    time_major = False
    new_shape = list(shape)

    # New way: Detect via op._batch_rank and op._time_rank properties where these ranks are.
    if hasattr(op, "_batch_rank") and isinstance(op._batch_rank, int):
        add_batch_rank = True
        new_shape[op._batch_rank] = -1
    if hasattr(op, "_time_rank") and isinstance(op._time_rank, int):
        add_time_rank = True
        if op._time_rank == 0:
            time_major = True
        new_shape[op._time_rank] = -1
    # Remove the ranks marked with -1 above; what remains is the Space's core shape.
    shape = tuple(n for n in new_shape if n != -1)

    # Old way: Detect automatically whether the first rank(s) are batch and/or time rank.
    if add_batch_rank is False and add_time_rank is False and shape != () and shape[0] is None:
        if len(shape) > 1 and shape[1] is None:
            # Two leading unknown ranks: strip both and re-add them as batch + time rank.
            # Which one is time cannot be determined here; time_major stays False.
            shape = shape[2:]
            add_time_rank = True
        else:
            shape = shape[1:]
        # In both cases the first stripped rank is re-added as the batch rank.
        add_batch_rank = True
    # TODO: If op._batch_rank and/or op._time_rank are not set, set them now.

    base_dtype = op.dtype.base_dtype if hasattr(op.dtype, "base_dtype") else op.dtype
    # Under the pytorch backend, torch.uint8 tensors are treated as booleans
    # (original assumption: "PyTorch does not have a bool type").
    if get_backend() == "pytorch":
        if op.dtype is torch.uint8:
            base_dtype = bool
    base_dtype_str = str(base_dtype)

    # FloatBox (forward low/high hints, as the docstring promises for "F_..._low=.._high=.." keys).
    if "float" in base_dtype_str:
        return FloatBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                        time_major=time_major, dtype=convert_dtype(base_dtype, "np"), **low_high)
    # IntBox: an explicit `high` hint wins over the op's own `_num_categories` attribute.
    elif "int" in base_dtype_str:
        high_ = high if high is not None else getattr(op, "_num_categories", None)
        return IntBox(high_, shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                      time_major=time_major, dtype=convert_dtype(base_dtype, "np"))
    # a BoolBox
    elif "bool" in base_dtype_str:
        return BoolBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                       time_major=time_major)
    # a TextBox
    elif "string" in base_dtype_str:
        return TextBox(shape=shape, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                       time_major=time_major)

    raise RLGraphError(
        "ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))