コード例 #1
0
def test_multiply():
    """Align the saved ``multiply`` tensors against their reference dumps.

    Every file is resolved inside ``BASE_DIR/multiply/out``. Each
    ``ff_*.pt`` file is paired with its ``torch_*.pt`` counterpart, and
    all pairs are handed to ``align_tensors`` in a single call.
    """
    expand = prepend_dirname_fn(os.path.join(BASE_DIR, "multiply", "out"))
    # (label, ff-side file, torch-side file) for each tensor to compare.
    pairs = (
        ("multiply_out", "ff_out.pt", "torch_out.pt"),
        ("multiply_out_grad", "ff_out_grad.pt", "torch_out_grad.pt"),
    )
    align_tensors([
        TensorAlignmentData(label, expand(ff_file), expand(torch_file))
        for label, ff_file, torch_file in pairs
    ])
コード例 #2
0
def test_getitem():
    """Align the saved ``getitem`` output against its reference dump.

    Only the tensor labelled ``getitem_out`` is compared. It is read
    from ``ff_out.pt`` and ``torch_out.pt`` inside
    ``BASE_DIR/getitem/out``.
    """
    getitem_dir = os.path.join(BASE_DIR, "getitem", "out")
    in_dir = prepend_dirname_fn(getitem_dir)
    forward_check = TensorAlignmentData(
        "getitem_out",
        in_dir("ff_out.pt"),
        in_dir("torch_out.pt"),
    )
    align_tensors([forward_check])
コード例 #3
0
def test_view_embedding():
    """Align the saved ``view_embedding`` tensors against their reference dumps.

    Three tensors are compared, all resolved inside
    ``BASE_DIR/view_embedding/out``:

    - ``embedding_out``
    - ``embedding_out_grad``
    - ``embedding_weight_grad``
    """
    embedding_dir = os.path.join(BASE_DIR, "view_embedding", "out")
    to_path = prepend_dirname_fn(embedding_dir)
    specs = [
        ("embedding_out", "ff_out.pt", "torch_out.pt"),
        ("embedding_out_grad", "ff_out_grad.pt", "torch_out_grad.pt"),
        ("embedding_weight_grad", "ff_weight_grad.pt", "torch_weight_grad.pt"),
    ]
    checks = []
    for name, ff_name, torch_name in specs:
        checks.append(
            TensorAlignmentData(name, to_path(ff_name), to_path(torch_name))
        )
    align_tensors(checks)
コード例 #4
0
def test_linear():
    """Align the saved ``linear`` tensors against their reference dumps.

    Four ``ff_*.pt`` / ``torch_*.pt`` file pairs inside
    ``BASE_DIR/linear/out`` are compared. They are labelled:

    - ``linear_out``
    - ``linear_out_grad``
    - ``linear_weight_grad``
    - ``linear_bias_grad``
    """
    resolve = prepend_dirname_fn(os.path.join(BASE_DIR, "linear", "out"))
    table = (
        ("linear_out", "ff_out.pt", "torch_out.pt"),
        ("linear_out_grad", "ff_out_grad.pt", "torch_out_grad.pt"),
        ("linear_weight_grad", "ff_weight_grad.pt", "torch_weight_grad.pt"),
        ("linear_bias_grad", "ff_bias_grad.pt", "torch_bias_grad.pt"),
    )
    alignments = [
        TensorAlignmentData(tag, resolve(ff_file), resolve(torch_file))
        for tag, ff_file, torch_file in table
    ]
    align_tensors(alignments)