def test_pytorch_parity(tmpdir, cls_model, max_diff, num_epochs=4, num_runs=3):
    """Verify that the same pytorch and lightning models achieve the same results.

    Both the lightning loop and the vanilla pytorch loop are run ``num_runs``
    times for ``num_epochs`` epochs. Their losses must agree to 5 decimal
    places. Their timings (excluding the first run) are then checked with
    ``tutils.assert_speed_parity_absolute`` using tolerance ``max_diff``.

    Args:
        tmpdir: pytest temporary directory fixture (unused here).
        cls_model: model class passed through to both training loops.
        max_diff: tolerance forwarded to ``tutils.assert_speed_parity_absolute``.
        num_epochs: epochs per run (not a fixture, since it has a default).
        num_runs: number of repeated runs; must be at least 2 because the
            first run is excluded from the timing comparison.
    """
    # The first run is dropped from timing below, so at least two are needed
    # for the speed comparison to check anything.
    assert num_runs > 1, "num_runs must be >= 2: the first run is excluded from timing"

    lightning_outs, pl_times = lightning_loop(cls_model, num_runs, num_epochs)
    manual_outs, pt_times = vanilla_loop(cls_model, num_runs, num_epochs)

    # zip() stops at the shorter sequence. Without this check, a loop that
    # reports fewer runs would silently skip comparisons instead of failing.
    assert len(lightning_outs) == len(manual_outs), \
        "lightning and vanilla loops returned a different number of runs"

    # Make sure the losses match to 5 decimal places. assert_almost_equal
    # checks abs(desired - actual) < 1.5 * 10**-5.
    for pl_out, pt_out in zip(lightning_outs, manual_outs):
        np.testing.assert_almost_equal(pl_out, pt_out, decimal=5)

    # The first run initializes the dataset (download & filter), so its
    # timing is excluded from the speed comparison.
    tutils.assert_speed_parity_absolute(
        pl_times[1:], pt_times[1:], nb_epochs=num_epochs, max_diff=max_diff
    )
def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
    """Verify that the same pytorch and lightning models achieve the same results.

    Both ``lightning_loop`` and ``vanilla_loop`` are expected to return a
    mapping with ``'losses'`` and ``'durations'`` entries. The losses must
    agree to 5 decimal places. The durations (excluding the first run) are
    checked with ``tutils.assert_speed_parity_absolute`` using tolerance
    ``max_diff``.

    Args:
        tmpdir: pytest temporary directory fixture (unused here).
        cls_model: model class passed through to both training loops.
        max_diff: tolerance forwarded to ``tutils.assert_speed_parity_absolute``.
        num_epochs: epochs per run (not a fixture, since it has a default).
        num_runs: number of repeated runs; must be at least 2 because the
            first run is excluded from the timing comparison.
    """
    # The first run is dropped from timing below, so at least two are needed
    # for the speed comparison to check anything.
    assert num_runs > 1, "num_runs must be >= 2: the first run is excluded from timing"

    lightning = lightning_loop(cls_model, num_runs, num_epochs)
    vanilla = vanilla_loop(cls_model, num_runs, num_epochs)

    # zip() stops at the shorter sequence. Without this check, a loop that
    # reports fewer runs would silently skip comparisons instead of failing.
    assert len(lightning['losses']) == len(vanilla['losses']), \
        "lightning and vanilla loops returned a different number of runs"

    # Make sure the losses match to 5 decimal places. assert_almost_equal
    # checks abs(desired - actual) < 1.5 * 10**-5.
    for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
        np.testing.assert_almost_equal(pl_out, pt_out, decimal=5)

    # The first run initializes the dataset (download & filter), so its
    # timing is excluded from the speed comparison.
    tutils.assert_speed_parity_absolute(
        lightning['durations'][1:],
        vanilla['durations'][1:],
        nb_epochs=num_epochs,
        max_diff=max_diff,
    )