Example #1
0
def get_space_from_op(op):
    """
    Tries to re-create a Space object given some DataOp (e.g. a tf op).
    This is useful for shape inference on returned ops after having run through a graph_fn.

    dicts and tuples are processed recursively into Dict/Tuple Spaces. Python scalars, numpy arrays and
    lists are converted directly. Everything else is treated as a backend tensor whose shape and dtype
    determine the resulting box Space.

    Args:
        op (DataOp): The op to create a corresponding Space for.

    Returns:
        Union[Space,int]: The inferred Space object, or 0 if no Space can be derived from `op`
            (no `dtype` attribute, unknown shape) or from any op nested inside a dict/tuple `op`.

    Raises:
        RLGraphError: If `op` is a str, or a tensor whose dtype is none of float, int, bool or string.
    """
    # a Dict
    if isinstance(op, dict):  # DataOpDict
        spec = {}
        add_batch_rank = False
        add_time_rank = False
        for key, value in op.items():
            space = get_space_from_op(value)
            # A nested op without a Space (0) carries no `has_batch_rank`/`has_time_rank` attributes:
            # the whole container then has no Space either (same rule as for tuples below).
            if space == 0:
                return 0
            spec[key] = space
            if space.has_batch_rank:
                add_batch_rank = True
            if space.has_time_rank:
                add_time_rank = True
        return Dict(spec,
                    add_batch_rank=add_batch_rank,
                    add_time_rank=add_time_rank)
    # a Tuple
    elif isinstance(op, tuple):  # DataOpTuple
        spec = []
        add_batch_rank = False
        add_time_rank = False
        for i in op:
            space = get_space_from_op(i)
            # One Space-less item makes the whole tuple Space-less.
            if space == 0:
                return 0
            spec.append(space)
            if space.has_batch_rank:
                add_batch_rank = True
            if space.has_time_rank:
                add_time_rank = True
        return Tuple(spec,
                     add_batch_rank=add_batch_rank,
                     add_time_rank=add_time_rank)
    # primitive Space -> infer from op dtype and shape
    else:
        # Op itself is a single value, simple python type.
        if isinstance(op, (bool, int, float)):
            return BoxSpace.from_spec(spec=type(op), shape=())
        elif isinstance(op, str):
            raise RLGraphError(
                "Cannot derive Space from non-allowed op ({})!".format(op))
        # A single numpy array.
        elif isinstance(op, np.ndarray):
            return BoxSpace.from_spec(spec=convert_dtype(str(op.dtype), "np"),
                                      shape=op.shape)
        elif isinstance(op, list):
            return try_space_inference_from_list(op)
        # No Space: e.g. the tf.no_op, a distribution (anything that's not a tensor).
        # PyTorch Tensors do not have get_shape so must check backend.
        elif hasattr(op, "dtype") is False or (get_backend() == "tf" and
                                               not hasattr(op, "get_shape")):
            return 0
        # Some tensor: can be converted into a BoxSpace.
        else:
            shape = get_shape(op)
            # Unknown shape (e.g. a cond op).
            if shape is None:
                return 0
            add_batch_rank = False
            add_time_rank = False
            time_major = False
            new_shape = list(shape)

            # New way: Detect via op._batch_rank and op._time_rank properties where these ranks are.
            # The flagged positions are marked with -1 and stripped from the shape afterwards.
            if hasattr(op, "_batch_rank") and isinstance(op._batch_rank, int):
                add_batch_rank = True
                new_shape[op._batch_rank] = -1

            if hasattr(op, "_time_rank") and isinstance(op._time_rank, int):
                add_time_rank = True
                # Time rank in the very first position means time-major layout.
                if op._time_rank == 0:
                    time_major = True
                new_shape[op._time_rank] = -1
            shape = tuple(n for n in new_shape if n != -1)

            # Old way: Detect automatically whether the first rank(s) are batch and/or time rank.
            # Only used if neither rank was flagged explicitly and the leading dimension is unknown (None).
            if add_batch_rank is False and add_time_rank is False and shape != (
            ) and shape[0] is None:
                if len(shape) > 1 and shape[1] is None:
                    # Two unknown leading dims: Both batch- and time-rank are present. Which one goes
                    # where cannot be determined here, so `time_major` stays False.
                    shape = shape[2:]
                    add_time_rank = True
                else:
                    shape = shape[1:]
                add_batch_rank = True

            # TODO: If op._batch_rank and/or op._time_rank are not set, set them now.

            # Use the `base_dtype` attribute of the op's dtype where it exists.
            base_dtype = op.dtype.base_dtype if hasattr(
                op.dtype, "base_dtype") else op.dtype
            # Under the pytorch backend, uint8 tensors are interpreted as bools.
            if get_backend() == "pytorch":
                if op.dtype is torch.uint8:
                    base_dtype = bool
            base_dtype_str = str(base_dtype)

            # FloatBox
            if "float" in base_dtype_str:
                return FloatBox(shape=shape,
                                add_batch_rank=add_batch_rank,
                                add_time_rank=add_time_rank,
                                time_major=time_major,
                                dtype=convert_dtype(base_dtype, "np"))
            # IntBox
            elif "int" in base_dtype_str:
                # An optional `_num_categories` attribute on the op serves as the IntBox' limit.
                high = getattr(op, "_num_categories", None)
                return IntBox(high,
                              shape=shape,
                              add_batch_rank=add_batch_rank,
                              add_time_rank=add_time_rank,
                              time_major=time_major,
                              dtype=convert_dtype(base_dtype, "np"))
            # a BoolBox
            elif "bool" in base_dtype_str:
                return BoolBox(shape=shape,
                               add_batch_rank=add_batch_rank,
                               add_time_rank=add_time_rank,
                               time_major=time_major)
            # a TextBox
            elif "string" in base_dtype_str:
                return TextBox(shape=shape,
                               add_batch_rank=add_batch_rank,
                               add_time_rank=add_time_rank,
                               time_major=time_major)

    # Only reached for tensors whose dtype matched none of the branches above.
    raise RLGraphError(
        "ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))
Example #2
0
def get_space_from_op(op,
                      read_key_hints=False,
                      dtype=None,
                      low=None,
                      high=None):
    """
    Derives a Space object from some DataOp (e.g. a tf op), if possible.
    Mainly meant for shape inference on the ops that a graph_fn returns.

    Args:
        op (DataOp): The op for which a matching Space should be built.

        read_key_hints (bool): Whether to parse type- and low/high-hints out of the (str) keys of a dict `op`.
            - Prefix "I_": IntBox, "F_": FloatBox, "B_": BoolBox.
            - Suffix "_low=0.0": Low value.
            - Suffix "_high=1.0": High value.
            E.g. the key "F_somekey_low=0.0_high=2.0" stands for a FloatBox with low=0.0 and high=2.0,
                 the key "I_somekey" stands for an IntBox without limits,
                 the key "I_somekey_high=5" stands for an IntBox with high=5 (values 0-4).

            Default: False.

        dtype (Optional[str]): Optional hint for the `dtype` of a BoxSpace.
        low (Optional[int,float]): Optional hint for the `low` property of a BoxSpace.
        high (Optional[int,float]): Optional hint for the `high` property of a BoxSpace.

    Returns:
        Space: The inferred Space object. 0 is returned instead if no Space could be derived from `op`
            (or from one of the ops nested inside a dict/tuple `op`).

    Raises:
        RLGraphError: For str ops and for tensors whose dtype is none of float, int, bool or string.
    """
    # Container case 1: DataOpDict -> Dict Space (built recursively).
    if isinstance(op, dict):
        child_spaces = {}
        any_batch_rank = False
        any_time_rank = False
        for name, sub_op in op.items():
            # Hints read from the key replace the hints that were valid so far.
            if read_key_hints is True:
                dtype, low, high = get_space_hints_from_dict_key(name)
            child = get_space_from_op(sub_op, dtype=dtype, low=low, high=high)
            child_spaces[name] = child
            # A single child without a Space -> no Space for the whole dict.
            if child == 0:
                return 0
            any_batch_rank = bool(child.has_batch_rank) or any_batch_rank
            any_time_rank = bool(child.has_time_rank) or any_time_rank
        return Dict(child_spaces,
                    add_batch_rank=any_batch_rank,
                    add_time_rank=any_time_rank)

    # Container case 2: DataOpTuple -> Tuple Space (built recursively, hints are not passed on).
    if isinstance(op, tuple):
        child_spaces = []
        any_batch_rank = False
        any_time_rank = False
        for sub_op in op:
            child = get_space_from_op(sub_op)
            if child == 0:
                return 0
            child_spaces.append(child)
            any_batch_rank = bool(child.has_batch_rank) or any_batch_rank
            any_time_rank = bool(child.has_time_rank) or any_time_rank
        return Tuple(child_spaces,
                     add_batch_rank=any_batch_rank,
                     add_time_rank=any_time_rank)

    # From here on: Primitive Space, inferred from the op's dtype and shape.
    # Only those limit hints that were actually given are forwarded.
    bounds = {
        name: value
        for name, value in (("high", high), ("low", low))
        if value is not None
    }

    # Plain python scalar.
    if isinstance(op, (bool, int, float)):
        return BoxSpace.from_spec(spec=(dtype or type(op)), shape=(), **bounds)

    # Strings are not allowed as ops.
    if isinstance(op, str):
        raise RLGraphError(
            "Cannot derive Space from non-allowed op ({})!".format(op))

    # Numpy array: dtype and shape are taken from the array itself.
    if isinstance(op, np.ndarray):
        np_dtype = convert_dtype(str(op.dtype), "np")
        return BoxSpace.from_spec(spec=np_dtype, shape=op.shape, **bounds)

    # Python list: delegated.
    if isinstance(op, list):
        return try_space_inference_from_list(op, dtype=dtype, **bounds)

    # Not a tensor at all (e.g. tf.no_op, a distribution) -> no Space.
    if not hasattr(op, "dtype"):
        return 0
    # PyTorch tensors have no `get_shape`, hence this check is limited to the tf backend.
    if get_backend() == "tf" and not hasattr(op, "get_shape"):
        return 0

    # Some tensor: Becomes a box Space.
    dims = get_shape(op)
    # Shape unknown (e.g. a cond op).
    if dims is None:
        return 0

    has_batch_rank = False
    has_time_rank = False
    is_time_major = False
    marked_dims = list(dims)

    # Explicit detection: `_batch_rank`/`_time_rank` attributes give the positions of these ranks.
    # These positions get marked with -1 and are dropped from the shape below.
    batch_pos = getattr(op, "_batch_rank", None)
    if isinstance(batch_pos, int):
        has_batch_rank = True
        marked_dims[batch_pos] = -1

    time_pos = getattr(op, "_time_rank", None)
    if isinstance(time_pos, int):
        has_time_rank = True
        is_time_major = is_time_major or time_pos == 0
        marked_dims[time_pos] = -1

    dims = tuple(dim for dim in marked_dims if dim != -1)

    # Implicit (legacy) detection: Leading unknown (None) dims count as batch (and time) rank.
    if not (has_batch_rank or has_time_rank) and dims != () and dims[0] is None:
        two_unknown_leading_dims = len(dims) > 1 and dims[1] is None
        # Batch- and time-rank are both there, but which one goes where is not known:
        # `is_time_major` is left at False.
        has_time_rank = two_unknown_leading_dims
        dims = dims[2:] if two_unknown_leading_dims else dims[1:]
        has_batch_rank = True

    # TODO: If op._batch_rank and/or op._time_rank are not set, set them now.

    # Prefer the dtype's `base_dtype` attribute where there is one.
    base_dtype = getattr(op.dtype, "base_dtype", op.dtype)
    # uint8 tensors count as bools under the pytorch backend.
    if get_backend() == "pytorch" and op.dtype is torch.uint8:
        base_dtype = bool
    dtype_name = str(base_dtype)

    # Constructor args shared by all box types.
    box_kwargs = dict(shape=dims,
                      add_batch_rank=has_batch_rank,
                      add_time_rank=has_time_rank,
                      time_major=is_time_major)

    if "float" in dtype_name:
        return FloatBox(**box_kwargs, dtype=convert_dtype(base_dtype, "np"))
    if "int" in dtype_name:
        # The `high` hint wins over the op's optional `_num_categories` attribute.
        upper = high or getattr(op, "_num_categories", None)
        return IntBox(upper, **box_kwargs, dtype=convert_dtype(base_dtype, "np"))
    if "bool" in dtype_name:
        return BoolBox(**box_kwargs)
    if "string" in dtype_name:
        return TextBox(**box_kwargs)

    # Tensor with a dtype that none of the box types covers.
    raise RLGraphError(
        "ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))