Example #1
0
        def handle_shape_gather(tensor):
            """Fold ``Gather(Shape(x), indices)`` into a constant when the result is static.

            ``get_producer`` / ``get_input`` are helpers defined outside this view; the
            checks below assume they return ``None`` when the pattern does not match.

            Args:
                tensor: Tensor to inspect. It is only handled when it is produced by a
                    ``Gather`` node whose data input is produced by a ``Shape`` node.

            Returns:
                An int64 numpy array with the selected dimensions of ``x``, shaped like
                the gather indices (a 0-d array for scalar indices), or ``None`` when
                the pattern does not match, the indices are not a ``Constant``, the
                shape of ``x`` is unknown, an index is out of range, or a selected
                dimension is dynamic.
            """
            gather = get_producer(tensor, "Gather")
            if gather is None:
                return None

            data = gather.inputs[0]
            indices_tensor = gather.inputs[1]

            inp = get_input(get_producer(data, "Shape"))
            if inp is None or inp.shape is None:
                return None

            if not isinstance(indices_tensor, Constant):
                return None

            indices = np.asarray(indices_tensor.values)
            rank = len(inp.shape)
            flat_indices = [int(index) for index in indices.reshape(-1)]
            # Out-of-range indices make the graph invalid; decline to fold instead of
            # letting an IndexError escape.
            if any(index < -rank or index >= rank for index in flat_indices):
                return None

            if not indices.shape:  # Scalar-case
                shape = inp.shape[flat_indices[0]]
                if misc.is_dynamic_dimension(shape):
                    return None
                return np.array(shape, dtype=np.int64)

            # Iterating a multi-dimensional index array yields sub-arrays, which cannot
            # index a shape. So look up the flattened indices, then restore the indices'
            # shape; Gather on a 1-D input yields an output shaped like its indices.
            shape = [inp.shape[index] for index in flat_indices]
            if misc.is_dynamic_shape(shape):
                return None

            return np.array(shape, dtype=np.int64).reshape(indices.shape)
Example #2
0
        def handle_shape(tensor):
            """Return the static shape behind a ``Shape`` node as an int64 array.

            Args:
                tensor: Tensor to inspect; only handled when ``get_producer`` finds a
                    ``Shape`` node producing it.

            Returns:
                ``np.ndarray`` of dtype int64 holding the input's dimensions, or
                ``None`` when there is no such input, its shape is unknown, or
                ``misc.is_dynamic_shape`` reports the shape as dynamic.
            """
            source = get_input(get_producer(tensor, "Shape"))
            dims = None if source is None else source.shape

            # Nothing to fold unless every dimension is known.
            if dims is None or misc.is_dynamic_shape(dims):
                return None

            return np.array(dims, dtype=np.int64)
Example #3
0
        def fold_shape_slice(tensor):
            """Fold ``Slice(Shape(x), ...)`` into a constant when the result is static.

            Slice parameters are read from node inputs (``starts``, ``ends``,
            ``axes``, ``steps`` at input positions 1-4) when present, otherwise from
            the node's ``starts``/``ends``/``axes``/``steps`` attributes.

            Args:
                tensor: Tensor to inspect; only handled when it is produced by a
                    ``Slice`` node whose data input is produced by a ``Shape`` node.

            Returns:
                An int64 numpy array with the sliced dimensions of ``x``, or ``None``
                when the pattern does not match, a parameter is not a ``Constant``,
                the slice is not along axis 0, the step is zero, the shape of ``x``
                is unknown, or a sliced dimension is dynamic.
            """
            # Named ``slice_node`` to avoid shadowing the builtin ``slice``.
            slice_node = get_producer(tensor, "Slice")
            if slice_node is None:
                return None

            data = slice_node.inputs[0]

            if len(slice_node.inputs) >= 3:
                starts, ends = slice_node.inputs[1:3]
                if any(not isinstance(t, Constant) for t in [starts, ends]):
                    return None
                starts, ends = get_scalar_value(starts), get_scalar_value(ends)
            elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
                starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0]
            else:
                return None

            inp = get_input(get_producer(data, "Shape"))
            if inp is None or inp.shape is None:
                return None

            # A shape tensor is 1-D, so the only valid slice axis is 0 (equivalently -1,
            # which counts from the back).
            if len(slice_node.inputs) > 3:
                axes = slice_node.inputs[3]
                if not isinstance(axes, Constant):
                    return None

                if get_scalar_value(axes) not in (0, -1):
                    return None
            elif "axes" in slice_node.attrs:
                if slice_node.attrs["axes"][0] not in (0, -1):
                    return None

            steps = 1
            if len(slice_node.inputs) > 4:
                steps = slice_node.inputs[4]
                if not isinstance(steps, Constant):
                    return None
                steps = get_scalar_value(steps)
            elif "steps" in slice_node.attrs:
                steps = slice_node.attrs["steps"][0]

            # A zero step is invalid for Slice, and Python slicing would raise ValueError.
            if steps == 0:
                return None

            shape = inp.shape[starts:ends:steps]
            if misc.is_dynamic_shape(shape):
                return None

            return np.array(shape, dtype=np.int64)
Example #4
0
        def lower_shape(tensor):
            """Replace a ``Shape`` node's output by its input's static shape, if known.

            Args:
                tensor: Tensor to inspect; only handled when it has exactly one
                    producer node and that node's op is ``"Shape"``.

            Returns:
                ``np.ndarray`` of dtype int64 with the input's dimensions, or ``None``
                if the pattern does not match, the input is already listed in
                ``graph_constants``, its shape is unknown, or the shape is dynamic.
            """
            producers = tensor.inputs
            if len(producers) != 1 or producers[0].op != "Shape":
                return None

            # A well-formed Shape node has exactly one input.
            shape_source = producers[0].inputs[0]

            # Inputs already known to be constant are folded by other means; skip them.
            if shape_source.name in graph_constants:
                return None

            dims = shape_source.shape
            if dims is None or misc.is_dynamic_shape(dims):
                return None

            return np.array(dims, dtype=np.int64)
Example #5
0
        def handle_shape_slice(tensor):
            """Fold ``Slice(Shape(x), ...)`` into a constant when the result is static.

            Slice parameters come from node inputs (positions 1-4: ``starts``,
            ``ends``, ``axes``, ``steps``) when present. Otherwise they come from the
            node's ``starts``/``ends``/``axes``/``steps`` attributes.

            Args:
                tensor: Tensor to inspect; only handled when it is produced by a
                    ``Slice`` node whose data input is produced by a ``Shape`` node.

            Returns:
                An int64 numpy array with the sliced dimensions of ``x``, or ``None``
                when the pattern does not match, a parameter is not a ``Constant``,
                the slice is not along axis 0, the step is zero, the shape of ``x``
                is unknown, or a sliced dimension is dynamic.
            """
            # Named ``slice_node`` to avoid shadowing the builtin ``slice``.
            slice_node = get_producer(tensor, "Slice")
            if slice_node is None:
                return None

            def get_value(value_tensor):
                # Gets the value of a Constant holding a single item (scalar or 1-element).
                if not value_tensor.shape:
                    return value_tensor.values
                return list(value_tensor.values)[0]

            data = slice_node.inputs[0]

            inp = get_input(get_producer(data, "Shape"))
            if inp is None or inp.shape is None:
                return None

            # Input-based Slice carries starts/ends as inputs 1-2; attribute-based Slice
            # has a single input, so unpacking inputs[1:3] unconditionally would raise.
            if len(slice_node.inputs) >= 3:
                starts, ends = slice_node.inputs[1:3]
                if any(not isinstance(t, Constant) for t in [starts, ends]):
                    return None
                starts, ends = get_value(starts), get_value(ends)
            elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
                starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0]
            else:
                return None

            # A shape tensor is 1-D, so the only valid slice axis is 0 (equivalently -1).
            if len(slice_node.inputs) > 3:
                axes = slice_node.inputs[3]
                if not isinstance(axes, Constant):
                    return None

                if get_value(axes) not in (0, -1):
                    return None
            elif "axes" in slice_node.attrs:
                if slice_node.attrs["axes"][0] not in (0, -1):
                    return None

            steps = 1
            if len(slice_node.inputs) > 4:
                steps = slice_node.inputs[4]
                if not isinstance(steps, Constant):
                    return None

                steps = get_value(steps)
            elif "steps" in slice_node.attrs:
                steps = slice_node.attrs["steps"][0]

            # A zero step is invalid for Slice, and Python slicing would raise ValueError.
            if steps == 0:
                return None

            shape = inp.shape[starts:ends:steps]
            if misc.is_dynamic_shape(shape):
                return None

            return np.array(shape, dtype=np.int64)