Пример #1
0
 def run_test(inp_shp, fn):
     """Compare ``fn`` on a MegEngine tensor with ``fn`` on the NumPy input.

     A random float32 array of shape ``inp_shp`` is created; applying ``fn`` to
     the raw array yields the reference. If the enclosing ``symbolic`` is not
     ``None``, ``fn`` is wrapped with ``jit.trace`` first. The comparison is
     performed three times.
     """
     host_input = np.random.randn(*inp_shp).astype(np.float32)
     device_input = megengine.tensor(host_input)
     expected = fn(host_input)
     runner = fn if symbolic is None else jit.trace(symbolic=symbolic)(fn)
     for _ in range(3):
         np.testing.assert_equal(runner(device_input).numpy(), expected)
Пример #2
0
def test_sgd_momentum(monkeypatch, trace_mode, inplace_mode):
    """Check the SGD momentum buffer across training and inference steps.

    The ``MEGENGINE_INPLACE_UPDATE`` environment variable is set from
    ``inplace_mode`` for the duration of the test. The sequence is: one
    training step, three inference steps, one more training step; after each
    phase the loss and/or the ``momentum_buffer`` state of ``net.a`` is
    compared with the expected value.

    Args:
        monkeypatch: pytest fixture used to set the environment variable.
        trace_mode: ``None`` to run eagerly, otherwise passed as ``symbolic``
            to ``trace`` for both the train and eval functions.
        inplace_mode: converted with ``str(int(...))`` for the env variable.
    """
    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))

        # NOTE: both helpers read the closure variable ``net``; the ``model``
        # keyword is accepted but not used inside them.
        def train_func(data, *, model=None, optim=None, gm=None):
            optim.clear_grad()
            with gm:
                loss = net(data)
                gm.backward(loss)
            optim.step()
            return loss

        if trace_mode is not None:
            train_func = trace(symbolic=trace_mode)(train_func)

        def eval_func(data, *, model=None, optim=None, gm=None):
            loss = net(data)
            return loss

        if trace_mode is not None:
            eval_func = trace(symbolic=trace_mode)(eval_func)

        net = Simple()
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        gm = ad.GradManager().attach(net.parameters())
        data = tensor([2.34])
        train_func(data, model=net, optim=optim, gm=gm)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 2.34)

        # do 3 steps of infer
        for _ in range(3):
            loss = eval_func(data)
            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34),
                                           5)
            # inference must leave the optimizer state untouched
            np.testing.assert_almost_equal(
                optim._state[net.a]["momentum_buffer"].numpy(), 2.34)

        # do a step of train
        # Capture the returned loss so the assertion below checks this training
        # step instead of the stale value left over from the inference loop.
        # The loss is computed before the parameter update, so it equals the
        # inference loss verified above.
        loss = train_func(data, model=net, optim=optim, gm=gm)
        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34,
            5)
Пример #3
0
        def set_b(symbolic, *args):
            """Trace an index-assignment function and run it on ``args``.

            The traced function uses ``inp.set_ai`` when any entry of the
            enclosing ``mge_idx`` is a ``list`` or ``Tensor``, and
            ``inp.set_subtensor`` otherwise.
            """
            def set_func(inp, val):
                needs_advanced = any(
                    isinstance(item, (list, Tensor)) for item in mge_idx)
                setter = inp.set_ai if needs_advanced else inp.set_subtensor
                return setter(val)[mge_idx]

            return trace(set_func, symbolic=symbolic)(*args)
Пример #4
0
        def get_b(symbolic, *args):
            """Trace an index-read function and run it on ``args``.

            The traced function reads through ``inp.ai`` when any entry of the
            enclosing ``mge_idx`` is a ``list`` or ``Tensor``, and through
            plain ``inp[...]`` indexing otherwise.
            """
            def get_func(inp):
                if any(isinstance(item, (list, Tensor)) for item in mge_idx):
                    return inp.ai[mge_idx]
                return inp[mge_idx]

            return trace(get_func, symbolic=symbolic)(*args)
Пример #5
0
 def run_test(func, args, ref_shape, is_trace, sym=False):
     """Assert that ``func`` produces an output of shape ``ref_shape``.

     Every element of ``args`` is converted to a float32 tensor. Without
     tracing the function is called once; with ``is_trace`` it is wrapped by
     ``trace(symbolic=sym)`` and called three times, checking the shape each
     time.
     """
     inputs = [tensor(value, dtype="float32") for value in args]
     if not is_trace:
         result = func(*inputs)
         assert result.numpy().shape == ref_shape, result.numpy().shape
         return
     traced = trace(symbolic=sym)(func)
     for _ in range(3):
         result = traced(*inputs)
         assert result.numpy().shape == ref_shape
Пример #6
0
    def worker(data, expect):
        """Scatter this rank's slice of ``data`` and compare with ``expect``.

        ``axis`` and ``symbolic`` come from the enclosing scope; the scatter
        call is executed inside a traced function.
        """
        rank = dist.get_rank()
        inp = tensor(data[rank])

        @trace(symbolic=symbolic)
        def func():
            return scatter(inp, axis=axis)

        assert np.allclose(func().numpy(), expect[rank])
Пример #7
0
def run_train(
    model_path,
    use_jit,
    use_symbolic,
    sublinear_memory_config=None,
    max_err=None,
    use_adaptive_pooling=False,
):
    """
    Run a single training iteration from a stored checkpoint and verify it.

    The checkpoint supplies the initial weights, learning rate, input data and
    labels, plus the reference loss and reference updated weights. Both the
    loss and every entry of the updated state dict are compared with those
    references using ``max_err`` as absolute tolerance (``1e-5`` when omitted).

    If you believe a failure is caused by numerical rounding rather than a
    bug, dump a new file with updated results by calling update_model.
    Please think twice before you do so.
    """
    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    opt = SGD(net.parameters(), lr=checkpoint["sgd_lr"])
    gm = ad.GradManager().attach(net.parameters())

    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)

    tolerance = 1e-5 if max_err is None else max_err

    if use_jit:
        step = jit.trace(
            train,
            symbolic=use_symbolic,
            sublinear_memory_config=sublinear_memory_config,
        )
    else:
        step = train

    opt.clear_grad()
    loss = step(data, label, net, opt, gm)
    opt.step()

    np.testing.assert_allclose(loss.numpy(), checkpoint["loss"], atol=tolerance)

    updated = net.state_dict().items()
    reference = checkpoint["net_updated"].items()
    for (name, value), (ref_name, ref_value) in zip(updated, reference):
        assert name == ref_name
        if "bn" in name:
            # reference entries for "bn" keys are reshaped to the current shape
            ref_value = ref_value.reshape(value.shape)
        np.testing.assert_allclose(value, ref_value, atol=tolerance)
0
def test_grad_scaler():
    """Exercise ``GradScaler.backward`` and ``unscale`` eagerly and traced."""

    def f():
        manager = GradManager()
        grad_scaler = GradScaler()
        x = mge.tensor(1.0)

        for _ in range(3):
            with manager:
                y = x + 1
                manager.attach(y)
                loss = y + 1
                grad_scaler.backward(manager, loss, unscale_grad=False)
            # backward was asked not to unscale, so the grad equals the factor
            np.testing.assert_equal(y.grad.numpy(), grad_scaler.scale_factor)
            grad_scaler.unscale(manager.attached_tensors())
            np.testing.assert_equal(y.grad.numpy(), 1)

        # test handle None elements
        grad_scaler.unscale(manager.attached_tensors())

    # run once eagerly, then once through trace
    f()
    trace(f)()
Пример #9
0
    def _reset_jit_graph(self, impl: callable):
        """Wrap ``impl`` with ``jit.trace`` and apply weight clipping after it.

        The returned callable forwards all arguments to the traced ``impl``.
        When ``self.training`` is truthy it then calls
        ``self._apply_lipshitz_constraint()`` before returning the result.
        """
        traced_impl = jit.trace(impl)

        def _(*args, **kwargs):
            result = traced_impl(*args, **kwargs)
            if not self.training:
                return result
            # dynamically apply weight clipping
            self._apply_lipshitz_constraint()
            return result

        return _
Пример #10
0
def test_rng_empty_tensor(is_symbolic):
    """Check that random ops return outputs of the requested empty shapes.

    Covers ``uniform``, ``normal``, ``gamma``, ``beta`` and ``poisson`` for
    several shapes containing a zero dimension, and ``permutation`` with
    ``n=0``. With ``is_symbolic`` set the functions are traced and called
    three times; otherwise they are called once.
    """
    set_global_seed(1024)
    shapes = [(0, ), (0, 0, 0), (10, 0, 10)]

    def fn(shape):
        return (
            random.uniform(0, 1, shape),
            random.normal(0, 1, shape),
            random.gamma(2, 1, shape),
            random.beta(2, 1, shape),
            random.poisson(2, shape),
        )

    # eager mode performs a single pass, traced mode three
    repeats = 1 if is_symbolic is None else 3

    for shape in shapes:
        fn_ = fn if is_symbolic is None else jit.trace(symbolic=is_symbolic)(fn)
        for _ in range(repeats):
            for out in fn_(shape):
                np.testing.assert_equal(out.numpy().shape, shape)

    def fn2(n):
        return random.permutation(n=n)

    if is_symbolic is not None:
        fn2 = jit.trace(symbolic=is_symbolic)(fn2)

    for _ in range(repeats):
        np.testing.assert_equal(fn2(0).numpy().shape, (0, ))
    mge.core.set_option("async_level", 2)
Пример #11
0
 def run(use_trace, symbolic):
     """Compare ``run_saved_context`` (optionally traced) with ``F.sigmoid``.

     Both the forward values and the gradients with respect to the input are
     required to be close.
     """
     a = tensor(np.array([1926.0817], dtype=np.float32))
     net = Sigmoid()
     if use_trace:
         func_run = trace(run_saved_context, symbolic=symbolic)
     else:
         func_run = run_saved_context
     actual = func_run(a, net=net)
     reference = F.sigmoid(a)
     assertTensorClose(actual.numpy(), reference.numpy())
     actual_grad = F.grad(actual, a, use_virtual_grad=False).numpy()
     reference_grad = F.grad(reference, a, use_virtual_grad=False).numpy()
     assertTensorClose(actual_grad, reference_grad)
Пример #12
0
    def run_index(index):
        """Run case ``index`` through ``func`` and check the results.

        When ``test_trace`` is set and ``network`` is falsy, the case is also
        run traced (non-symbolic and symbolic), then traced with
        ``capture_as_const=True``, dumped to an in-memory buffer, reloaded via
        ``cgtools.GraphInference`` and checked again. Finally the plain
        ``func`` call is checked.
        """
        inp, outp = get_param(cases, index)
        inp_tensor = [make_tensor(item, network) for item in inp]

        if test_trace and not network:
            traced_inputs = list(inp_tensor)
            for symbolic in (False, True):
                traced_func = trace(symbolic=symbolic)(func)

                for _ in range(3):
                    traced_results = traced_func(*traced_inputs, **kwargs)
                # only the result of the last of the three runs is checked
                check_results(traced_results, outp)

            dumped_func = trace(symbolic=True, capture_as_const=True)(func)
            check_results(dumped_func(*traced_inputs, **kwargs), outp)

            buffer = io.BytesIO()
            dump_info = dumped_func.dump(buffer)
            buffer.seek(0)

            input_names = dump_info[4]
            inps_np = [t.numpy() for t in traced_inputs]
            # arg_name has pattern arg_xxx, xxx is int value
            input_names.sort(key=lambda arg_name: int(arg_name.split("_")[-1]))
            inp_dict = dict(zip(input_names, inps_np))
            infer_cg = cgtools.GraphInference(buffer)

            # assume #outputs == 1
            loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
            # scalar info lost
            check_results(loaded_results, outp, check_shape=False)

        check_results(
            func(*inp_tensor, **kwargs), outp, check_shape=(network is None))
Пример #13
0
def test_pixel_shuffle_symbolic(is_symbolic, type):
    """Compare ``F.pixel_shuffle`` with the reference ``pixel_shuffle`` helper.

    The input is a (3, 4, 5, 5) ``arange`` array cast to ``type`` and the
    upscale factor is 2. Traced mode runs three times, eager mode once.
    """
    def fn(inp, upscale_factor):
        return F.pixel_shuffle(inp, upscale_factor=upscale_factor)

    eager = is_symbolic is None
    if not eager:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    source = np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type)
    inp = tensor(source)
    golden = pixel_shuffle(inp, 2)
    for _ in range(1 if eager else 3):
        np.testing.assert_equal(fn(inp, 2).numpy(), golden)
Пример #14
0
def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
    """Assert that ``F.matmul`` on the given shapes yields only zeros.

    Traced mode (``is_symbolic`` not ``None``) runs three times, eager once.
    """
    def func(a, b):
        return F.matmul(a, b)

    traced = is_symbolic is not None
    if traced:
        func = jit.trace(symbolic=is_symbolic)(func)

    lhs = tensor(np.random.randn(*shape_a))
    rhs = tensor(np.random.randn(*shape_b))
    for _ in range(3 if traced else 1):
        product = func(lhs, rhs)
        assert np.all(product.numpy() == 0)
Пример #15
0
def test_batchnorm_change_batchsize():
    """Call a traced ``BatchNorm2d`` with batch size 2, then with batch size 4.

    The traced output for the second input must be close to the output of the
    untraced module on the same data.
    """
    first_shape = (2, 3, 8, 8)
    second_shape = (4, 3, 8, 8)
    first_batch = np.random.random(first_shape).astype(np.float32)
    second_batch = np.random.random(second_shape).astype(np.float32)

    bn = BatchNorm2d(3)
    traced_bn = trace(bn)
    traced_bn(first_batch)  # initial call uses the batch-size-2 input

    traced_out = traced_bn(second_batch)
    eager_out = bn(tensor(second_batch))

    assertTensorClose(eager_out.numpy(), traced_out.numpy())
Пример #16
0
def test_copy_empty(shape, device_src, device_dst, is_symbolic):
    """Copy a tensor from ``device_src`` to ``device_dst`` and check the result.

    The output must keep ``shape`` and report ``device_dst`` as its device.
    Traced mode runs three times, eager mode once.
    """
    inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)

    def func(inp):
        return F.copy(inp, device_dst)

    traced = is_symbolic is not None
    if traced:
        func = trace(symbolic=is_symbolic)(func)

    for _ in range(3 if traced else 1):
        copied = func(inp)
        assert copied.numpy().shape == shape
        assert copied.device == device_dst
Пример #17
0
def test_jit_trace():
    """Trace a traced module, dump it, reload it and compare all outputs.

    The jit-traced function must reproduce the eager module output exactly;
    the graph reloaded through ``cgtools.GraphInference`` must match within
    an absolute tolerance of 1e-6.
    """
    module = MyModule()
    module.eval()
    x = F.ones((1, 8, 14, 14))
    expect = module(x)

    func = trace(trace_module(module, x), capture_as_const=True)
    np.testing.assert_array_equal(func(x), expect)

    # dump into memory and rewind so the graph can be loaded back
    model = io.BytesIO()
    func.dump(model)
    model.seek(0)

    infer_cg = cgtools.GraphInference(model)
    loaded_outputs = list(infer_cg.run(x.numpy()).values())
    np.testing.assert_allclose(loaded_outputs[0], expect, atol=1e-6)
Пример #18
0
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
    """Compare ``F.roll`` with ``np.roll`` on a random tensor of ``shape``.

    Args:
        shape: shape of the random float32 input.
        shifts: forwarded to both ``F.roll`` and ``np.roll``.
        axis: forwarded to both ``F.roll`` and ``np.roll``.
        is_symbolic: ``None`` for a single eager run; otherwise the value
            passed as ``symbolic`` to ``trace``, with three runs.
    """
    inp = tensor(np.random.randn(*shape).astype("float32"))

    def func(inp):
        return F.roll(inp, shifts, axis)

    if is_symbolic is not None:
        func = trace(symbolic=is_symbolic)(func)

    out_ref = np.roll(inp.numpy(), shifts, axis)
    for _ in range(3):
        # Call through ``func`` so the traced variant is actually exercised;
        # calling ``F.roll`` directly here would bypass the trace entirely.
        out = func(inp)
        np.testing.assert_equal(out.numpy(), out_ref)
        if is_symbolic is None:
            break
Пример #19
0
    def worker():
        """Broadcast a tensor inside a ``GradManager`` scope and backprop it.

        Rank 0 feeds the output of a ``Linear(10, 10)`` layer into the
        broadcast, the other ranks feed the raw input. The step is traced
        when the enclosing ``trace_mode`` is not ``None``.
        """
        x = F.ones([3, 10], dtype="float32")
        m = M.Linear(10, 10)

        def func():
            with GradManager().attach(m.parameters()) as gm:
                y = m(x) if dist.get_rank() == 0 else x
                gm.backward(F.distributed.broadcast(y))

        runner = func if trace_mode is None else trace(symbolic=trace_mode)(func)
        runner()
Пример #20
0
    def worker():
        """Run ``reduce_sum`` inside a ``GradManager`` scope and backprop.

        Rank 0 builds a scalar loss from the reduced value and calls
        ``gm.backward(loss)``; the other ranks call ``gm.backward()`` with no
        argument. The step is traced when the enclosing ``trace_mode`` is not
        ``None``.
        """
        m = M.Linear(10, 10)
        x = F.ones([3, 10], dtype="float32")

        def func():
            with GradManager().attach(m.parameters()) as gm:
                reduced = F.distributed.reduce_sum(m(x))
                if dist.get_rank() != 0:
                    gm.backward()
                else:
                    loss = (2 * reduced + 1).mean()
                    gm.backward(loss)

        runner = func if trace_mode is None else trace(symbolic=trace_mode)(func)
        runner()
Пример #21
0
def run_test(
    model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
):
    """
    Run a single training iteration from a stored checkpoint and verify it.

    The checkpoint supplies the initial weights, learning rate, input data and
    labels, plus the reference loss and reference updated weights. The loss
    and every entry of the updated state dict are compared with those
    references using ``max_err`` (``1e-5`` when omitted).

    If you believe a failure is caused by numerical rounding rather than a
    bug, dump a new file with updated results by calling update_model.
    Please think twice before you do so.
    """
    net = MnistNet(has_bn=True)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    opt = SGD(net.parameters(), lr=checkpoint["sgd_lr"])

    # create typed placeholder tensors first, then fill them from the checkpoint
    data = tensor(dtype=np.float32)
    label = tensor(dtype=np.int32)
    data.set_value(checkpoint["data"])
    label.set_value(checkpoint["label"])

    tolerance = 1e-5 if max_err is None else max_err

    if use_jit:
        step = jit.trace(
            train,
            symbolic=use_symbolic,
            sublinear_memory_config=sublinear_memory_config,
        )
    else:
        step = train

    opt.zero_grad()
    loss = step(data, label, net=net, opt=opt)
    opt.step()

    assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=tolerance)

    updated = net.state_dict().items()
    reference = checkpoint["net_updated"].items()
    for (name, value), (ref_name, ref_value) in zip(updated, reference):
        assert name == ref_name
        assertTensorClose(value, ref_value, max_err=tolerance)
Пример #22
0
def test_PermutationRNG(symbolic):
    """Check ``RNG.permutation`` for seeding, device, shape and content.

    Generators with equal seeds must agree across devices, while a different
    seed or a different op must give a different result. A permutation of
    1000 must leave fewer than 500 fixed points and sort back to 0..999. The
    tensor-argument forms are checked inside a traced function.
    """
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.permutation(1000)
    out1_ = m1.uniform(size=(1000, ))
    out2 = m2.permutation(1000)
    out3 = m3.permutation(1000)

    np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()

    def assert_shape(result, expected):
        # ``shape`` may be a plain tuple or a tensor-like object with ``numpy()``
        shp = result.shape
        if isinstance(shp, tuple):
            assert shp == expected
        else:
            assert all(result.shape.numpy() == np.array(list(expected)))

    out = m1.permutation(1000)
    assert_shape(out, (1000, ))

    def sum_result(res, fun):
        # number of positions whose value equals its own index
        return sum(1 for i, v in enumerate(fun(res.numpy())) if i == v)

    assert sum_result(out, lambda x: x) < 500
    assert sum_result(out, np.sort) == 1000

    def func():
        assert_shape(m1.permutation(Tensor(7)), (1, ))
        n, m = 6, 3
        source = Tensor(np.arange(n * m), dtype="float32").reshape(n, m)
        assert_shape(m1.permutation(source), (n, m))

    func = trace(symbolic=symbolic)(func)
    func()
Пример #23
0
def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
    """Check that a batch-norm module built with ``freeze=True`` stays fixed.

    Running statistics are preset to constants, then an SGD training function
    (three optimizer steps per call) is invoked three times. Weight and bias
    must stay equal to their initial values; in eval mode the running
    statistics must also stay unchanged and the loss must match the value
    computed from the preset mean and variance.
    """
    nchannel = 3
    m = BNModule(nchannel, freeze=True)
    if is_training:
        m.train()
    else:
        m.eval()

    var, bias = 4.0, 1.0
    stat_shape = (1, nchannel, 1, 1)
    m.running_var[...] = var * F.ones(stat_shape)
    m.running_mean[...] = bias * F.ones(stat_shape)

    # snapshot everything that is expected to stay unchanged
    saved_var = m.running_var.numpy()
    saved_mean = m.running_mean.numpy()
    saved_wt = m.weight.numpy()
    saved_bias = m.bias.numpy()

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = np.random.random((6, nchannel, 2, 2)).astype("float32")

    def train_fn(d):
        for _ in range(3):
            with gm:
                loss = m(d).mean()
                gm.backward(loss)
            optim.step()
        return loss

    if use_trace:
        train_fn = trace(train_fn, symbolic=use_symbolic)

    for _ in range(3):
        loss = train_fn(megengine.tensor(data))
        if not is_training:
            np.testing.assert_equal(m.running_var.numpy(), saved_var)
            np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
            expected_loss = ((data - bias) / np.sqrt(var)).mean()
            np.testing.assert_almost_equal(loss.numpy(), expected_loss, 5)
        np.testing.assert_equal(m.weight.numpy(), saved_wt)
        np.testing.assert_equal(m.bias.numpy(), saved_bias)
Пример #24
0
    def worker(data):
        """Run ``all_to_all`` and ``gather`` in a traced function and compare.

        On rank 0 the gather of the all-to-all output along ``split_axis``
        must be close to the gather of the raw input along ``concat_axis``.
        Other ranks only run the collectives.
        """
        rank = dist.get_rank()
        inp = tensor(data[rank])

        @trace(symbolic=symbolic)
        def func():
            exchanged = all_to_all(inp,
                                   split_axis=split_axis,
                                   concat_axis=concat_axis)
            gather_C = gather(inp, axis=concat_axis)
            gather_B = gather(exchanged, axis=split_axis)
            return (gather_B, gather_C) if rank == 0 else exchanged

        ret = func()
        if rank == 0:
            assert np.allclose(ret[0], ret[1])
Пример #25
0
def test_profiler(format, trace_mode):
    """Profile a small inference run and validate the dumped profile file.

    The profile is written into a temporary directory under the name
    ``<pid>.<format>``. For the ``chrome_timeline.json`` format the events
    are additionally checked: durations are non-negative, timestamps do not
    decrease per ``tid``, and ``my_scope`` events appear a positive, even
    number of times.

    Args:
        format: profile format, also used as the file extension.
        trace_mode: falsy to run eagerly, otherwise passed as ``symbolic``
            to ``trace``.
    """

    def infer():
        with scope("my_scope"):
            oup = Simple()(tensor([1.23], dtype="float32"))
            return oup

    if trace_mode:
        infer = trace(symbolic=trace_mode)(infer)

    # Using the directory as a context manager removes it deterministically,
    # including when an assertion below fails, instead of relying on the
    # implicit finalizer of a dangling TemporaryDirectory object.
    with tempfile.TemporaryDirectory() as profile_prefix:
        profile_path = os.path.join(profile_prefix,
                                    "{}.{}".format(os.getpid(), format))

        with Profiler(profile_prefix, format=format):
            infer()

        print(profile_path)
        assert os.path.exists(profile_path), "profiling results not found"

        if format != "chrome_timeline.json":
            return

        # read the events while the temporary directory still exists
        with open(profile_path, "r") as f:
            events = json.load(f)

    if isinstance(events, dict):
        assert "traceEvents" in events
        events = events["traceEvents"]
    prev_ts = {}
    scope_count = 0
    for event in events:
        if "dur" in event:
            assert event["dur"] >= 0
        elif "ts" in event and "tid" in event:
            ts = event["ts"]
            tid = event["tid"]
            if ts != 0:
                # timestamps must be non-decreasing within one tid
                assert (tid not in prev_ts) or prev_ts[tid] <= ts
                prev_ts[tid] = ts
        if "name" in event and event["name"] == "my_scope":
            scope_count += 1
    assert scope_count > 0 and scope_count % 2 == 0
Пример #26
0
def test_broadcast_on_empty_tensor(is_trace):
    """Check ``F.broadcast_to`` on inputs that contain a zero-sized dimension.

    Each case pairs an input tensor with a target shape; the output's
    ``_tuple_shape`` must equal the target. With ``is_trace`` every case is
    traced both non-symbolically and symbolically and called three times.
    """
    def make(shape):
        return tensor(np.random.random(shape).astype(np.float32))

    cases = [
        [make((100, 0, 1)), (100, 0, 10)],
        [make((10, 0)), (10, 10, 0)],
        [make((0, 0, 1, 10)), (10, 0, 0, 10, 10)],
    ]

    def func(x, shp):
        return F.broadcast_to(x, shp)

    def check(fn, inp, target_shp):
        assert fn(inp, target_shp)._tuple_shape == target_shp

    if not is_trace:
        for inp, target_shp in cases:
            check(func, inp, target_shp)
        return

    for symbolic in (False, True):
        for inp, target_shp in cases:
            func_traced = trace(symbolic=symbolic)(func)
            for _ in range(3):
                check(func_traced, inp, target_shp)
Пример #27
0
def test_batchnorm_empty_tensor(dim, is_symbolic):
    """Run a training-mode batch norm on an empty input and compare to input.

    ``dim`` selects ``BatchNorm1d`` (input shape (0, 4, 0)) or ``BatchNorm2d``
    (input shape (0, 4, 0, 0)); any other value raises
    ``NotImplementedError``. Traced mode runs three times, eager mode once.
    """
    if dim not in (1, 2):
        raise NotImplementedError

    if dim == 1:
        bn_cls, inp_shape = BatchNorm1d, (0, 4, 0)
    else:
        bn_cls, inp_shape = BatchNorm2d, (0, 4, 0, 0)
    m = bn_cls(4, affine=True)
    inp = mge.tensor(np.random.randn(*inp_shape).astype("float32"))

    m.train()

    def fn(inp):
        return m(inp)

    traced = is_symbolic is not None
    if traced:
        fn = jit.trace(symbolic=is_symbolic)(fn)
    for _ in range(3 if traced else 1):
        np.testing.assert_equal(fn(inp).numpy(), inp)
Пример #28
0
def test_trace_advance_indexing(shape_mode):
    """Compare indexing expressions across NumPy, eager and traced execution.

    Each lambda in ``funcs`` is evaluated on the raw NumPy inputs (reference),
    on tensors eagerly, and three times through a non-symbolic ``trace`` with
    ``symbolic_shape=shape_mode``. Arguments are selected from ``inputs`` by
    the lambda's parameter names.
    """
    funcs = [
        lambda x, i: x[i],
        # lambda x, i, j: x[i, j],  # FIXME
        lambda x, i, j: x[i, :, j, ...],
        # lambda x, start, end: x[start:end],  # FIXME
        lambda x, start, end: x[:, 0, start:end, ..., 1],
        lambda x, vec: x[vec],
        lambda x, vec: x[vec, ..., 0, 1:3],
        lambda x, vec: x[vec, vec[0], vec[1]],
        # lambda x, i, start, end, vec: x[i, ..., :, vec, start:end],  # FIXME
        lambda x, mask: x[mask],
    ]

    inputs = {
        "x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
        "i": 0,
        "j": 2,
        "start": 1,
        "end": 3,
        "vec": [1, 2, 3],
        "mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
    }
    for f in funcs:
        # Use the public ``Signature.parameters`` mapping rather than the
        # private ``_parameters`` attribute, which is an implementation detail.
        param_names = list(inspect.signature(f).parameters)
        f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
        params = {name: tensor(inputs[name]) for name in param_names}
        params_np = {name: inputs[name] for name in param_names}
        expected = f(**params_np)
        result_imperative = f(**params)
        np.testing.assert_equal(expected, result_imperative.numpy())
        for _ in range(3):
            result_trace = f_traced(**params)
            np.testing.assert_equal(expected, result_trace.numpy())
Пример #29
0
def test_sort_empty(is_symbolic):
    """Compare ``F.sort`` on empty inputs with ``np.sort`` / ``np.argsort``.

    Args:
        is_symbolic: ``None`` for a single eager run per shape; otherwise the
            value passed as ``symbolic`` to ``jit.trace``, with three runs
            per shape.
    """
    data_shapes = [
        (0,),
        (10, 0),
    ]

    def fn(x):
        return F.sort(x)

    for shape in data_shapes:
        if is_symbolic is not None:
            fn_ = jit.trace(symbolic=is_symbolic)(fn)
        else:
            fn_ = fn
        data = np.random.random(shape).astype(np.float32)
        for _ in range(3):
            outs = fn_(tensor(data))
            ref_outs = (np.sort(data), np.argsort(data))
            assert len(ref_outs) == len(outs)
            for out, ref in zip(outs, ref_outs):
                np.testing.assert_equal(out.numpy(), ref)
            # The early exit belongs to the repeat loop: eager mode needs one
            # run per shape. At the outer level it would abandon the shape
            # loop after the first shape and leave the rest untested.
            if is_symbolic is None:
                break
Пример #30
0
def test_nms(is_symbolic):
    """Check ``F.vision.nms`` on four boxes and on an empty set of boxes.

    ``max_output`` is ``None`` in eager mode and 4 when traced. Each case is
    evaluated three times and must return the expected int32 indices.
    """
    def fn(inp, scores):
        return F.vision.nms(
            inp,
            scores=scores,
            iou_thresh=0.5,
            max_output=None if is_symbolic is None else 4,
        )

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    def check(boxes, box_scores, expected):
        inp = tensor(boxes)
        scores = tensor(box_scores, dtype=np.float32)
        for _ in range(3):
            result = fn(inp, scores=scores)
            np.testing.assert_equal(result.numpy(), expected)

    boxes = np.array(
        [
            [0, 0, 100, 100],
            [10, 10, 100, 100],
            [50, 50, 100, 100],
            [100, 100, 150, 150],
        ],
        dtype=np.float32,
    )
    check(boxes, [0.5, 0.8, 0.9, 0.6], np.array([2, 1, 3], dtype=np.int32))

    no_boxes = np.array([], dtype=np.float32).reshape(0, 4)
    check(no_boxes, [], np.array([], dtype=np.int32))