def global_function_or_identity(*args, **kwargs):
    """Pick a job decorator appropriate for the current runtime mode.

    In ``rt_mode.NORMAL_MODE`` this forwards all arguments to
    ``flow.global_function`` and returns the decorator it produces.
    Otherwise the mode is asserted to be ``rt_mode.GLOBAL_MODE`` and a
    decorator that returns the wrapped function unchanged is returned; the
    forwarded arguments are ignored in that case.
    """
    if rt_mode.CurrentMode() != rt_mode.NORMAL_MODE:
        assert rt_mode.CurrentMode() == rt_mode.GLOBAL_MODE

        def _passthrough(func):
            # No wrapping: hand the function back as-is.
            return func

        return _passthrough
    return flow.global_function(*args, **kwargs)
def test_2n8c(test_case):
    """Run ``PretrainJob`` for 10 iterations with 4 GPUs per node and print losses.

    Loads model weights from ``FLAGS.model_load_dir`` before running, then
    prints the mean of each iteration's result.

    NOTE(review): the name suggests 2 nodes x 8 cards total; only the per-node
    GPU count (4) is configured here. Any multi-node setup must happen
    elsewhere -- confirm.
    """
    flow.config.gpu_device_num(4)
    # `type` is the first positional parameter of flow.global_function, so the
    # config must be passed by keyword. Passing it positionally would land it in
    # `type`. This also declares the job as a training job, as test_1n1c does.
    pretrain_job = flow.global_function(type="train", function_config=func_config)(
        PretrainJob
    )
    check_point = flow.train.CheckPoint()
    check_point.load(FLAGS.model_load_dir)
    # Each call returns a future-like object; .get() blocks for the result.
    of_loss = [pretrain_job().get().mean() for _ in range(10)]
    print(of_loss)
def test_1n1c(test_case):
    """Run ``PretrainJob`` as a training job on one GPU, in debug mode.

    Weights are loaded from ``FLAGS.model_load_dir``. The job then runs for
    10 iterations, and the list of per-iteration means is printed.
    """
    flow.config.enable_debug_mode(True)
    flow.config.gpu_device_num(1)
    train_decorator = flow.global_function(type="train", function_config=func_config)
    pretrain_job = train_decorator(PretrainJob)
    checkpoint = flow.train.CheckPoint()
    checkpoint.load(FLAGS.model_load_dir)
    losses = []
    for _ in range(10):
        # .get() waits for the iteration's result before reducing it.
        losses.append(pretrain_job().get().mean())
    print(losses)