def test_pytorch_parity(tmpdir, cls_model, max_diff):
    """
    Check that a Lightning model and its plain PyTorch counterpart agree.

    ``cls_model`` is trained by both ``lightning_loop`` and ``vanilla_loop``
    for the same number of runs and epochs. Each loop returns a pair of
    (per-run outputs, per-run timings). The outputs must match to 5 decimal
    places and the timings are compared with ``max_diff`` as the tolerance.
    ``tmpdir`` is accepted but not used in the body.
    """
    epochs = 4
    runs = 3
    pl_results, pl_durations = lightning_loop(cls_model, runs, epochs)
    pt_results, pt_durations = vanilla_loop(cls_model, runs, epochs)

    # Run by run, the losses have to agree to 5 decimal places.
    for lightning_loss, vanilla_loss in zip(pl_results, pt_results):
        np.testing.assert_almost_equal(lightning_loss, vanilla_loss, decimal=5)

    # The first run initializes the dataset (download & filter), so its
    # timing is left out of the speed comparison.
    pl_timed = pl_durations[1:]
    pt_timed = pt_durations[1:]
    tutils.assert_speed_parity_absolute(
        pl_timed, pt_timed, nb_epochs=epochs, max_diff=max_diff
    )
def test_pytorch_parity(
        tmpdir,
        cls_model,
        max_diff: float,
        num_epochs: int = 4,
        num_runs: int = 3):
    """
    Check that a Lightning model and its plain PyTorch counterpart agree.

    Args:
        tmpdir: accepted but not used in the body.
        cls_model: model class handed to both ``lightning_loop`` and
            ``vanilla_loop``.
        max_diff: tolerance forwarded to
            ``tutils.assert_speed_parity_absolute``.
        num_epochs: number of epochs passed to each loop and to the
            speed check.
        num_runs: number of runs passed to each loop.

    Both loops return a mapping with ``'losses'`` and ``'durations'``
    entries. The losses must match to 5 decimal places, then the durations
    are compared.
    """
    pl_stats = lightning_loop(cls_model, num_runs, num_epochs)
    pt_stats = vanilla_loop(cls_model, num_runs, num_epochs)

    # Run by run, the losses have to agree to 5 decimal places.
    pl_losses = pl_stats['losses']
    pt_losses = pt_stats['losses']
    for lightning_loss, vanilla_loss in zip(pl_losses, pt_losses):
        np.testing.assert_almost_equal(lightning_loss, vanilla_loss, decimal=5)

    # The first run initializes the dataset (download & filter), so its
    # timing is left out of the speed comparison.
    pl_timed = pl_stats['durations'][1:]
    pt_timed = pt_stats['durations'][1:]
    tutils.assert_speed_parity_absolute(
        pl_timed, pt_timed, nb_epochs=num_epochs, max_diff=max_diff
    )