def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Make output ``var_name`` of ``op`` use one of the accepted orders.

    If the current output variable is already in one of ``target_orders``
    nothing happens. Otherwise ``op`` is given a fresh output variable re-laid
    out in the first accepted order, and a transpose back to the original order
    is spliced in where the old output variable used to be.

    Args:
        op: Operator whose output is inspected (and possibly replaced).
        var_name: Key of the output in ``op.outputs``.
        target_orders: A single acceptable Order, or a sequence of them. When
            a replacement is needed, the first entry is the one used.

    Returns:
        True if the graph was modified, False if the output was already in an
        accepted order.
    """
    old_output = op.outputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if old_output.order in accepted:
        # Already acceptable: leave the graph untouched.
        return False

    # New variable with the same shape/order, then re-laid out into the preferred order.
    new_output = Variable(old_output.shape, old_output.order).change_order(accepted[0])
    op.replace_output(old_output, new_output, with_assert=False)

    # Transpose back to the original order; presumably ``replace`` makes this
    # transposed variable stand in for the old output for its consumers — verify.
    restored = new_output.transpose(old_output.order)
    restored.replace(old_output, with_assert=False)
    return True
def test_transpose():
    """Transposing an NHWC variable to NCHW yields an NCHW variable produced by a Transpose op."""
    source = Variable([2, 3, 4, 5], OrderNHWC)
    result = source.transpose(OrderNCHW)

    # The (2, 3, 4, 5) extents laid out as NHWC, rearranged into N, C, H, W.
    expected_shape = (2, 5, 3, 4)
    assert result.shape == expected_shape, result.shape
    assert result.order == OrderNCHW

    # The new variable must come from a Transpose whose input is the source variable.
    producer = result.output_from
    assert isinstance(producer, Transpose)
    assert producer.inputs["x0"] == source
def template(x_order=OrderNHWC, y_order=OrderNCHW, description: str = "", shape=None):
    """Generate a Transpose kernel test case converting ``x_order`` into ``y_order``.

    A random input array laid out in ``x_order`` is created. The expected output
    is computed with ``np.transpose``, using for each axis of ``y_order`` its
    position in ``x_order``. The graph under test is a single
    ``Variable.transpose`` call, checked on the webgpu, webgl and webassembly
    backends.

    Args:
        x_order: Order of the input variable.
        y_order: Order of the output variable. Every axis of ``y_order`` is
            looked up in ``x_order.axes_dict``, so presumably both orders must
            contain the same axes.
        description: Text appended to the test-case description.
        shape: Input shape, laid out in ``x_order``. Defaults to
            ``(2, 3, ..., ndim + 1)`` with ``ndim`` the number of axes of
            ``x_order``; for 4-D orders this is ``(2, 3, 4, 5)``, the shape
            previously hard-coded here.

    Raises:
        ValueError: If ``shape`` does not provide exactly one extent per axis
            of ``x_order``.
    """
    ndim = len(x_order.axes)
    if shape is None:
        # Distinct extents per axis make a wrong permutation visible in the output shape.
        shape = tuple(range(2, 2 + ndim))
    elif len(shape) != ndim:
        raise ValueError(f"shape {tuple(shape)} has {len(shape)} dims, but x_order has {ndim} axes")

    vx = np.random.rand(*shape)
    # Permutation that rearranges an x_order-laid-out array into y_order layout.
    vy = np.transpose(vx, [x_order.axes_dict[a] for a in y_order.axes])

    x = Variable(vx.shape, order=x_order)
    y = x.transpose(y_order)

    generate_kernel_test_case(
        description=f"Transpose {description}",
        backend=["webgpu", "webgl", "webassembly"],
        graph=Graph([x], [y]),
        inputs={x: vx},
        expected={y: vy},
    )