def test_convert_linear_layer_all():
    """Check a Gemm node with every attribute set against the ONNX reference.

    Uses alpha=0.25, beta=0.35, transA=0 and transB=1. Converts the node with
    ``convert_linear_layer``, passing ``b`` and ``c`` as the parameter
    initializers. Then compares the converted module's output on ``a`` with
    ``gemm_reference_implementation`` called with the same attributes.
    """
    node = onnx.helper.make_node(
        "Gemm",
        inputs=["a", "b", "c"],
        outputs=["y"],
        alpha=0.25,
        beta=0.35,
        transA=0,
        transB=1,
    )
    # Fixed seed so any failure is reproducible.
    rng = np.random.default_rng(0)
    # Drawn as (4, 3) and transposed, so ``a`` is a non-contiguous (3, 4) view.
    a = rng.random((4, 3)).astype(np.float32).transpose()
    # With transB=1, Gemm uses b.T, i.e. shape (4, 5).
    b = rng.random((5, 4)).astype(np.float32)
    # (1, 5) bias, broadcast across the rows of A @ B.T.
    c = rng.random((1, 5)).astype(np.float32)
    y = gemm_reference_implementation(
        a, b, c, transA=0, transB=1, alpha=0.25, beta=0.35
    )

    params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
    op = convert_linear_layer(node, params)
    op.eval()
    out = op(torch.from_numpy(a))

    expected = torch.from_numpy(y)
    # torch.allclose broadcasts its operands, so pin the shape explicitly.
    assert out.shape == expected.shape
    # torch.allclose only returns a bool; without ``assert`` the check is a no-op.
    assert torch.allclose(expected, out)


def test_convert_linear_layer_default():
    """Check a Gemm node with default attributes against the ONNX reference.

    No attributes are set on the node, and ``gemm_reference_implementation``
    is called without overrides. Both sides therefore use Gemm's defaults.
    ``b`` and ``c`` are passed to ``convert_linear_layer`` as the parameter
    initializers. The converted module's output on ``a`` must match the
    reference.
    """
    node = onnx.helper.make_node("Gemm", inputs=["a", "b", "c"], outputs=["y"])
    # Fixed seed so any failure is reproducible.
    rng = np.random.default_rng(0)
    a = rng.random((3, 6)).astype(np.float32)
    b = rng.random((6, 4)).astype(np.float32)
    # Bias with the full (3, 4) output shape of A @ B.
    c = rng.random((3, 4)).astype(np.float32)
    y = gemm_reference_implementation(a, b, c)

    params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
    op = convert_linear_layer(node, params)
    op.eval()
    out = op(torch.from_numpy(a))

    expected = torch.from_numpy(y)
    # torch.allclose broadcasts its operands, so pin the shape explicitly.
    assert out.shape == expected.shape
    # torch.allclose only returns a bool; without ``assert`` the check is a no-op.
    assert torch.allclose(expected, out)