def handle_shape_gather(tensor):
    """Try to fold ``Gather(Shape(x), indices)`` into a constant.

    Args:
        tensor: Candidate tensor. It is only folded when its producer is a
            ``Gather`` whose data input is produced by a ``Shape`` node and
            whose indices input is a ``Constant``.

    Returns:
        An ``np.int64`` array holding the gathered dimension(s) -- 0-d for
        scalar indices, 1-D for 1-D indices -- or ``None`` when the pattern
        does not match, the indices have rank > 1, or any gathered dimension
        is dynamic.
    """
    gather = get_producer(tensor, "Gather")
    if gather is None:
        return None

    data = gather.inputs[0]
    indices_tensor = gather.inputs[1]

    shape_node = get_producer(data, "Shape")
    inp = get_input(shape_node)
    if inp is None or inp.shape is None:
        return None

    if not isinstance(indices_tensor, Constant):
        return None

    # Shape (opset >= 15) may slice its output via optional ``start``/``end``
    # attributes. Python slicing uses the same negative-index and clamping
    # rules, so the gather indices must address this sliced view rather than
    # the full input shape.
    start = shape_node.attrs.get("start", 0)
    end = shape_node.attrs.get("end")
    dims = list(inp.shape)[start:end]

    indices = indices_tensor.values
    if len(indices.shape) > 1:
        # Rank > 1 indices would need a multi-dimensional result and cannot
        # be used to index a Python list; leave such Gathers unfolded.
        return None

    if not indices.shape:  # Scalar indices -> 0-d result.
        shape = dims[int(indices)]
        if misc.is_dynamic_dimension(shape):
            return None
    else:
        shape = [dims[int(index)] for index in indices]
        if misc.is_dynamic_shape(shape):
            return None

    return np.array(shape, dtype=np.int64)
def handle_shape(tensor):
    """Try to fold ``Shape(x)`` into a constant.

    Args:
        tensor: Candidate tensor. It is only folded when its producer is a
            ``Shape`` node whose input is returned by ``get_input``.

    Returns:
        A 1-D ``np.int64`` array with the dimensions the ``Shape`` node
        outputs, or ``None`` when the pattern does not match, the input shape
        is unknown, or any of the output dimensions is dynamic.
    """
    shape_node = get_producer(tensor, "Shape")
    inp = get_input(shape_node)
    if inp is None or inp.shape is None:
        return None

    # Shape (opset >= 15) may report only a slice of the input shape via the
    # optional ``start``/``end`` attributes. Python slicing has the same
    # negative-index and clamping semantics, and with neither attribute set
    # this is the full shape.
    start = shape_node.attrs.get("start", 0)
    end = shape_node.attrs.get("end")
    dims = list(inp.shape)[start:end]

    # Only the dimensions actually emitted need to be static.
    if misc.is_dynamic_shape(dims):
        return None
    return np.array(dims, dtype=np.int64)
def fold_shape_slice(tensor):
    """Try to fold ``Slice(Shape(x), ...)`` into a constant.

    Supports both the input-based Slice form (starts/ends[/axes[/steps]] as
    ``Constant`` inputs) and the attribute-based form (``starts``/``ends``/
    ``axes``/``steps`` attributes).

    Args:
        tensor: Candidate tensor produced by a ``Slice`` whose data input is
            produced by a ``Shape`` node.

    Returns:
        A 1-D ``np.int64`` array with the sliced dimensions, or ``None`` when
        the pattern does not match, any slice parameter is not a
        ``Constant``, the slice axis is not the (only) axis of the shape
        tensor, the step is zero, or a resulting dimension is dynamic.
    """
    slice_node = get_producer(tensor, "Slice")
    if slice_node is None:
        return None

    data = slice_node.inputs[0]

    if len(slice_node.inputs) >= 3:
        starts, ends = slice_node.inputs[1:3]
        if any(not isinstance(t, Constant) for t in (starts, ends)):
            return None
        starts, ends = get_scalar_value(starts), get_scalar_value(ends)
    elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
        starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0]
    else:
        return None

    shape_node = get_producer(data, "Shape")
    inp = get_input(shape_node)
    if inp is None or inp.shape is None:
        return None

    # A shape tensor is 1-D, so the only valid slice axis is 0, which may
    # also be spelled as its negative alias -1. With no axes given, all axes
    # (i.e. axis 0) are sliced.
    if len(slice_node.inputs) > 3:
        axes = slice_node.inputs[3]
        if not isinstance(axes, Constant):
            return None
        axis = get_scalar_value(axes)
    else:
        axis = slice_node.attrs.get("axes", [0])[0]
    if axis not in (0, -1):
        return None

    steps = 1
    if len(slice_node.inputs) > 4:
        steps_tensor = slice_node.inputs[4]
        if not isinstance(steps_tensor, Constant):
            return None
        steps = get_scalar_value(steps_tensor)
    elif "steps" in slice_node.attrs:
        steps = slice_node.attrs["steps"][0]
    if steps == 0:
        # ONNX forbids a zero step; Python slicing would raise ValueError.
        return None

    # Honor Shape's optional ``start``/``end`` attributes (opset >= 15) so the
    # slice operates on the same dimensions the Shape node actually outputs.
    shape_start = shape_node.attrs.get("start", 0)
    shape_end = shape_node.attrs.get("end")
    dims = list(inp.shape)[shape_start:shape_end]

    shape = dims[starts:ends:steps]
    if misc.is_dynamic_shape(shape):
        return None

    return np.array(shape, dtype=np.int64)
def lower_shape(tensor):
    """Try to lower a tensor produced directly by a ``Shape`` node to a constant.

    Args:
        tensor: Candidate tensor. It must have exactly one producer node, and
            that node must be a ``Shape`` op.

    Returns:
        A 1-D ``np.int64`` array with the dimensions the ``Shape`` node
        outputs, or ``None`` when the tensor is not produced by a single
        ``Shape`` node, the node is malformed (no input), its input is already
        in ``graph_constants``, the input shape is unknown, or any output
        dimension is dynamic.
    """
    if len(tensor.inputs) != 1:
        return None

    node = tensor.inputs[0]
    if node.op != "Shape":
        return None

    # Shape node must have 1 input or it's malformed.
    if not node.inputs:
        return None
    inp = node.inputs[0]

    # If the input was already found to be a constant, it will be folded anyway.
    if inp.name in graph_constants:
        return None

    if inp.shape is None:
        return None

    # Shape (opset >= 15) may report only a slice of the input shape via the
    # optional ``start``/``end`` attributes; Python slicing has the same
    # negative-index and clamping semantics.
    start = node.attrs.get("start", 0)
    end = node.attrs.get("end")
    dims = list(inp.shape)[start:end]

    if misc.is_dynamic_shape(dims):
        return None
    return np.array(dims, dtype=np.int64)
def handle_shape_slice(tensor):
    """Try to fold ``Slice(Shape(x), ...)`` into a constant.

    Handles both the input-based Slice form (starts/ends[/axes[/steps]] as
    ``Constant`` inputs) and the attribute-based Slice-1 form, which has
    only a single data input.

    Args:
        tensor: Candidate tensor produced by a ``Slice`` whose data input is
            produced by a ``Shape`` node.

    Returns:
        A 1-D ``np.int64`` array with the sliced dimensions, or ``None`` when
        the pattern does not match, a slice parameter is not a ``Constant``,
        the slice axis is not the (only) axis of the shape tensor, the step
        is zero, or a resulting dimension is dynamic.
    """

    def _scalar(const):
        # Integer value of a Constant holding a single item (0-d or 1-element).
        values = const.values
        if not const.shape:
            return int(values)
        return int(list(values)[0])

    slice_node = get_producer(tensor, "Slice")
    if slice_node is None:
        return None

    data = slice_node.inputs[0]

    shape_node = get_producer(data, "Shape")
    inp = get_input(shape_node)
    if inp is None or inp.shape is None:
        return None

    # Previously ``slice.inputs[1:3]`` was unpacked unconditionally, which
    # raised ValueError for attribute-based Slice nodes with one input.
    if len(slice_node.inputs) >= 3:
        starts, ends = slice_node.inputs[1:3]
        if any(not isinstance(t, Constant) for t in (starts, ends)):
            return None
        start_val, end_val = _scalar(starts), _scalar(ends)
    elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
        start_val = slice_node.attrs["starts"][0]
        end_val = slice_node.attrs["ends"][0]
    else:
        return None

    # A shape tensor is 1-D: axis 0 (or its negative alias -1) is the only
    # valid slice axis. With no axes given, all axes (i.e. axis 0) are sliced.
    if len(slice_node.inputs) > 3:
        axes = slice_node.inputs[3]
        if not isinstance(axes, Constant):
            return None
        axis = _scalar(axes)
    else:
        axis = slice_node.attrs.get("axes", [0])[0]
    if axis not in (0, -1):
        return None

    steps = 1
    if len(slice_node.inputs) > 4:
        steps_tensor = slice_node.inputs[4]
        if not isinstance(steps_tensor, Constant):
            return None
        steps = _scalar(steps_tensor)
    elif "steps" in slice_node.attrs:
        steps = slice_node.attrs["steps"][0]
    if steps == 0:
        # ONNX forbids a zero step; Python slicing would raise ValueError.
        return None

    # Honor Shape's optional ``start``/``end`` attributes (opset >= 15) so the
    # slice operates on the same dimensions the Shape node actually outputs.
    shape_start = shape_node.attrs.get("start", 0)
    shape_end = shape_node.attrs.get("end")
    dims = list(inp.shape)[shape_start:shape_end]

    shape = dims[start_val:end_val:steps]
    if misc.is_dynamic_shape(shape):
        return None

    return np.array(shape, dtype=np.int64)