Beispiel #1
0
 def test_train(self, hub_model, benchmark):
     """Benchmark the training step of a hub model.

     Puts ``hub_model`` into training mode, times ``hub_model.train`` through
     the ``benchmark`` callable, and records the machine state in
     ``benchmark.extra_info``. If any of these steps raises
     ``NotImplementedError``, a skip message is printed instead.
     """
     skip_message = 'Method train is not implemented, skipping...'
     try:
         hub_model.set_train()
         benchmark(hub_model.train)
         # Evaluate the state first, then store it alongside the timing data.
         state = get_machine_state()
         benchmark.extra_info['machine_state'] = state
     except NotImplementedError:
         print(skip_message)
Beispiel #2
0
 def test_eval(self, hub_model, benchmark, pytestconfig):
     """Benchmark the evaluation step of a hub model.

     Gradient tracking is suppressed (via ``no_grad``) only when the model
     asks for it through ``eval_in_nograd()`` AND the ``disable_nograd``
     pytest option is not set. Inside that context the model is switched to
     eval mode, ``hub_model.eval`` is timed with ``benchmark``, and the
     machine state is recorded in ``benchmark.extra_info``. A
     ``NotImplementedError`` from any step prints a skip message.
     """
     try:
         # Short-circuit: the pytest option is only consulted when the
         # model itself requests no-grad evaluation.
         use_nograd = (hub_model.eval_in_nograd()
                       and not pytestconfig.getoption("disable_nograd"))
         with no_grad(use_nograd):
             hub_model.set_eval()
             benchmark(hub_model.eval)
             state = get_machine_state()
             benchmark.extra_info['machine_state'] = state
     except NotImplementedError:
         print('Method eval is not implemented, skipping...')
Beispiel #3
0
    def test_train(self, model_path, device, compiler, benchmark):
        """Benchmark the training step of the model found at ``model_path``.

        Builds a ``ModelTask`` for ``model_path`` and returns early (without
        benchmarking) when ``task.model_details.exists`` is false. Otherwise
        it instantiates the model on ``device`` (JIT-compiled when
        ``compiler == 'jit'``), switches it to training mode, times
        ``task.train`` via ``benchmark``, and records the machine state in
        ``benchmark.extra_info``.

        A ``NotImplementedError`` raised by any step prints a skip message
        instead of failing.
        """
        try:
            task = ModelTask(model_path)
            if not task.model_details.exists:
                return  # Model is not supported.

            task.make_model_instance(device=device, jit=(compiler == 'jit'))
            task.set_train()
            benchmark(task.train)
            benchmark.extra_info['machine_state'] = get_machine_state()

        except NotImplementedError:
            # Fixed: this previously reported 'eval', which misidentified
            # the missing method in the *train* benchmark.
            print('Method train is not implemented, skipping...')
Beispiel #4
0
    def test_eval(self, model_path, device, compiler, benchmark, pytestconfig):
        """Benchmark the evaluation step of the model found at ``model_path``.

        Returns early, without benchmarking, when ``ModelTask(model_path)``
        reports that the model details do not exist. Otherwise the steps are:

        1. Instantiate the model on ``device`` (JIT-compiled when
           ``compiler == 'jit'``).
        2. Enter ``task.no_grad``, passing the ``disable_nograd`` pytest
           option through.
        3. Switch to eval mode and time ``task.eval`` via ``benchmark``.
        4. Record the machine state in ``benchmark.extra_info``.
        5. Optionally run ``task.check_opt_vs_noopt_jit()`` when the
           ``check_opt_vs_noopt_jit`` pytest option is set.

        A ``NotImplementedError`` from any step prints a skip message.
        """
        try:
            task = ModelTask(model_path)
            if not task.model_details.exists:
                return  # Model is not supported.

            use_jit = compiler == 'jit'
            task.make_model_instance(device=device, jit=use_jit)

            nograd_disabled = pytestconfig.getoption("disable_nograd")
            with task.no_grad(disable_nograd=nograd_disabled):
                task.set_eval()
                benchmark(task.eval)
                benchmark.extra_info['machine_state'] = get_machine_state()

                want_jit_check = pytestconfig.getoption("check_opt_vs_noopt_jit")
                if want_jit_check:
                    # NOTE(review): a NotImplementedError raised here is also
                    # reported as "eval not implemented" — confirm intended.
                    task.check_opt_vs_noopt_jit()

        except NotImplementedError:
            print('Method eval is not implemented, skipping...')