def test_distributed_evaluate_tf_example(self):
    """Run an evaluation-only job for ``test_module.custom_model``.

    Calls ``distributed_train_and_evaluate`` with ``training=False`` on the
    ``DatasetName.TEST_MODULE`` dataset and registers ``CheckRetryCallback``
    as the only callback class.
    """
    eval_options = {
        "training": False,
        "dataset_name": DatasetName.TEST_MODULE,
        "callback_classes": [CheckRetryCallback],
    }
    distributed_train_and_evaluate(
        1, _model_zoo_path, "test_module.custom_model", **eval_options
    )
def test_cifar10_evaluate(self):
    """Run an evaluation-only job for each CIFAR-10 model definition.

    Both the functional-API and the subclass model definitions are passed
    to ``distributed_train_and_evaluate`` with ``training=False``.
    """
    for cifar10_model in (
        "cifar10_functional_api.cifar10_functional_api.custom_model",
        "cifar10_subclass.cifar10_subclass.CustomModel",
    ):
        # A fresh shape list is built for every call, as in the original.
        distributed_train_and_evaluate(
            [32, 32, 3],
            _model_zoo_path,
            cifar10_model,
            training=False,
        )
def test_mnist_evaluate(self):
    """Run an evaluation-only job for each MNIST model definition.

    Both the functional-API and the subclass model definitions are passed
    to ``distributed_train_and_evaluate`` with ``training=False``.
    """
    for mnist_model in (
        "mnist_functional_api.mnist_functional_api.custom_model",
        "mnist_subclass.mnist_subclass.CustomModel",
    ):
        # A fresh shape list is built for every call, as in the original.
        distributed_train_and_evaluate(
            [28, 28],
            _model_zoo_path,
            mnist_model,
            training=False,
        )
def test_resnet50_subclass_evaluate(self):
    """Run an evaluation-only job for the ResNet50 subclass model.

    Uses the ``DatasetName.IMAGENET`` dataset and passes the model
    parameters as the ``;``-separated string expected by the callee.
    """
    eval_options = {
        "model_params": 'num_classes=10;dtype="float32"',
        "training": False,
        "dataset_name": DatasetName.IMAGENET,
    }
    distributed_train_and_evaluate(
        [224, 224, 3],
        _model_zoo_path,
        "resnet50_subclass.resnet50_subclass.CustomModel",
        **eval_options,
    )
def test_distributed_train_get_model_steps(self):
    """Run an async training job with ``get_model_steps=4``.

    Trains ``test_module.custom_model`` on ``DatasetName.TEST_MODULE`` with
    ``use_async=True`` and registers ``CheckWorkerModelCallback`` as the
    only callback class.
    """
    train_options = {
        "training": True,
        "dataset_name": DatasetName.TEST_MODULE,
        "callback_classes": [CheckWorkerModelCallback],
        "use_async": True,
        "get_model_steps": 4,
    }
    distributed_train_and_evaluate(
        1, _model_zoo_path, "test_module.custom_model", **train_options
    )
def test_deepfm_functional_evaluate(self):
    """Run an evaluation-only job for the DeepFM functional-API model.

    Uses the ``DatasetName.FRAPPE`` dataset with ``training=False``.
    """
    # Adjacent literals concatenate into the same single parameter string.
    deepfm_params = (
        "input_dim=5383;embedding_dim=4;"
        "input_length=10;fc_unit=4"
    )
    distributed_train_and_evaluate(
        10,
        _model_zoo_path,
        "deepfm_functional_api.deepfm_functional_api.custom_model",
        model_params=deepfm_params,
        training=False,
        dataset_name=DatasetName.FRAPPE,
    )
def test_resnet50_subclass_train(self):
    """Train the ResNet50 subclass model, first sync and then async.

    Each run calls ``distributed_train_and_evaluate`` with ``training=True``
    on the ``DatasetName.IMAGENET`` dataset.
    """
    resnet_model = "resnet50_subclass.resnet50_subclass.CustomModel"
    for async_mode in (False, True):
        distributed_train_and_evaluate(
            [224, 224, 3],
            _model_zoo_path,
            resnet_model,
            training=True,
            dataset_name=DatasetName.IMAGENET,
            use_async=async_mode,
        )
def test_deepfm_functional_train(self):
    """Train the DeepFM functional-API model, first sync and then async.

    Each run calls ``distributed_train_and_evaluate`` with ``training=True``
    on the ``DatasetName.FRAPPE`` dataset.
    """
    deepfm_model = "deepfm_functional_api.deepfm_functional_api.custom_model"
    # Adjacent literals concatenate into the same single parameter string.
    deepfm_params = (
        "input_dim=5383;embedding_dim=4;"
        "input_length=10;fc_unit=4"
    )
    for async_mode in (False, True):
        distributed_train_and_evaluate(
            10,
            _model_zoo_path,
            deepfm_model,
            model_params=deepfm_params,
            training=True,
            dataset_name=DatasetName.FRAPPE,
            use_async=async_mode,
        )
def _test_evaluate(
    self,
    feature_shape,
    model_def,
    model_params="",
    dataset_name=DatasetName.IMAGE_DEFAULT,
):
    """Run an evaluation-only job against freshly created parameter servers.

    Two parameter-server pods are created with ``create_pserver`` and
    handed to ``distributed_train_and_evaluate`` with ``training=False``.
    The servers are stopped and the channels closed afterwards, whether or
    not the evaluation raised.

    Args:
        feature_shape: Forwarded as the first positional argument of
            ``distributed_train_and_evaluate``.
        model_def: Model definition path, used both to create the
            parameter servers and to run the evaluation.
        model_params: Model parameter string forwarded unchanged.
        dataset_name: Dataset selector forwarded unchanged.

    Returns:
        The value returned by ``distributed_train_and_evaluate``.
    """
    num_ps_pods = 2
    grads_to_wait = 1
    # The fourth argument is False here; `_test_train` passes `use_async`
    # in that position.
    _, ps_channels, pservers = create_pserver(
        _model_zoo_path, model_def, grads_to_wait, False, num_ps_pods
    )
    try:
        model_version = distributed_train_and_evaluate(
            feature_shape,
            _model_zoo_path,
            model_def,
            model_params=model_params,
            training=False,
            dataset_name=dataset_name,
            ps_channels=ps_channels,
            pservers=pservers,
        )
    finally:
        for pserver in pservers:
            pserver.server.stop(0)
        # Close the channels as `_test_train` does; previously they were
        # left open after every evaluation run.
        for channel in ps_channels:
            channel.close()
    return model_version
def test_cifar10_train(self):
    """Train both CIFAR-10 models in sync and async mode and compare versions.

    The four runs are ordered (functional, sync), (functional, async),
    (subclass, sync), (subclass, async). For each model the async run's
    version is asserted to be twice the sync run's version.
    """
    cifar10_models = [
        "cifar10_functional_api.cifar10_functional_api.custom_model",
        "cifar10_subclass.cifar10_subclass.CustomModel",
    ]
    async_modes = [False, True]
    model_versions = [
        distributed_train_and_evaluate(
            [32, 32, 3],
            _model_zoo_path,
            model_def,
            training=True,
            use_async=async_mode,
        )
        for model_def, async_mode in itertools.product(
            cifar10_models, async_modes
        )
    ]
    # async model version = sync model version * 2
    for sync_index in (0, 2):
        self.assertEqual(
            model_versions[sync_index] * 2, model_versions[sync_index + 1]
        )
def _test_train(
    self,
    feature_shape,
    model_def,
    model_params="",
    dataset_name=DatasetName.IMAGE_DEFAULT,
):
    """Train once in sync mode and once in async mode on fresh parameter servers.

    For each mode two parameter-server pods are created with
    ``create_pserver`` (``grads_to_wait`` is 2 for sync and 1 for async).
    The training job is run, then the servers are stopped and the channels
    closed even if training raised.

    Returns:
        A two-element list with the value returned by
        ``distributed_train_and_evaluate`` for the sync run followed by the
        async run.
    """
    num_ps_pods = 2
    collected_versions = []
    for async_mode in (False, True):
        if async_mode:
            grads_to_wait = 1
        else:
            grads_to_wait = 2
        _, channels, servers = create_pserver(
            _model_zoo_path,
            model_def,
            grads_to_wait,
            async_mode,
            num_ps_pods,
        )
        try:
            version = distributed_train_and_evaluate(
                feature_shape,
                _model_zoo_path,
                model_def,
                model_params=model_params,
                training=True,
                dataset_name=dataset_name,
                use_async=async_mode,
                ps_channels=channels,
                pservers=servers,
            )
        finally:
            # Tear down in the same order as before: servers, then channels.
            for ps in servers:
                ps.server.stop(0)
            for ps_channel in channels:
                ps_channel.close()
        collected_versions.append(version)
    return collected_versions
def test_mnist_train(self):
    """Train both MNIST models in sync and async mode and compare versions.

    Currently disabled. The test is reported as skipped rather than
    returning early, so the runner no longer counts it as a pass. The body
    below the skip is kept so the test can be re-enabled by deleting the
    ``skipTest`` call.
    """
    # TODO(qijun) need to rewrite `distributed_train_and_evaluate`
    self.skipTest(
        "disabled until `distributed_train_and_evaluate` is rewritten"
    )
    model_defs = [
        "mnist_functional_api.mnist_functional_api.custom_model",
        "mnist_subclass.mnist_subclass.CustomModel",
    ]
    use_asyncs = [False, True]
    model_versions = []
    for model_def, use_async in itertools.product(model_defs, use_asyncs):
        model_version = distributed_train_and_evaluate(
            [28, 28],
            _model_zoo_path,
            model_def,
            training=True,
            use_async=use_async,
        )
        model_versions.append(model_version)
    # async model version = sync model version * 2
    self.assertEqual(model_versions[0] * 2, model_versions[1])
    self.assertEqual(model_versions[2] * 2, model_versions[3])