Exemplo n.º 1
0
 def test_distributed_evaluate_tf_example(self):
     """Evaluation-only run of ``test_module.custom_model``.

     Uses the ``TEST_MODULE`` dataset with ``CheckRetryCallback`` attached.
     """
     eval_options = {
         "training": False,
         "dataset_name": DatasetName.TEST_MODULE,
         "callback_classes": [CheckRetryCallback],
     }
     distributed_train_and_evaluate(
         1, _model_zoo_path, "test_module.custom_model", **eval_options
     )
Exemplo n.º 2
0
 def test_cifar10_evaluate(self):
     """Evaluate both CIFAR-10 model definitions (functional and subclass)."""
     cifar10_models = (
         "cifar10_functional_api.cifar10_functional_api.custom_model",
         "cifar10_subclass.cifar10_subclass.CustomModel",
     )
     for definition in cifar10_models:
         # A fresh shape list per call, exactly as before.
         distributed_train_and_evaluate(
             [32, 32, 3],
             _model_zoo_path,
             definition,
             training=False,
         )
Exemplo n.º 3
0
 def test_mnist_evaluate(self):
     """Evaluate both MNIST model definitions (functional and subclass)."""
     mnist_models = (
         "mnist_functional_api.mnist_functional_api.custom_model",
         "mnist_subclass.mnist_subclass.CustomModel",
     )
     for definition in mnist_models:
         # A fresh shape list per call, exactly as before.
         distributed_train_and_evaluate(
             [28, 28],
             _model_zoo_path,
             definition,
             training=False,
         )
Exemplo n.º 4
0
 def test_resnet50_subclass_evaluate(self):
     """Evaluate the ResNet-50 subclass model on the ImageNet dataset."""
     resnet_def = "resnet50_subclass.resnet50_subclass.CustomModel"
     resnet_params = 'num_classes=10;dtype="float32"'
     distributed_train_and_evaluate(
         [224, 224, 3],
         _model_zoo_path,
         resnet_def,
         model_params=resnet_params,
         training=False,
         dataset_name=DatasetName.IMAGENET,
     )
Exemplo n.º 5
0
 def test_distributed_train_get_model_steps(self):
     """Async training of the test module model with ``get_model_steps=4``.

     ``CheckWorkerModelCallback`` is attached to the run.
     """
     train_options = dict(
         training=True,
         dataset_name=DatasetName.TEST_MODULE,
         callback_classes=[CheckWorkerModelCallback],
         use_async=True,
         get_model_steps=4,
     )
     distributed_train_and_evaluate(
         1, _model_zoo_path, "test_module.custom_model", **train_options
     )
Exemplo n.º 6
0
 def test_deepfm_functional_evaluate(self):
     """Evaluate the functional-API DeepFM model on the Frappe dataset."""
     deepfm_params = (
         "input_dim=5383;embedding_dim=4;"
         "input_length=10;fc_unit=4"
     )
     deepfm_def = "deepfm_functional_api.deepfm_functional_api.custom_model"
     distributed_train_and_evaluate(
         10,
         _model_zoo_path,
         deepfm_def,
         model_params=deepfm_params,
         training=False,
         dataset_name=DatasetName.FRAPPE,
     )
Exemplo n.º 7
0
 def test_resnet50_subclass_train(self):
     """Train the ResNet-50 subclass model in sync mode, then async mode."""
     resnet_def = "resnet50_subclass.resnet50_subclass.CustomModel"
     for async_mode in (False, True):
         distributed_train_and_evaluate(
             [224, 224, 3],
             _model_zoo_path,
             resnet_def,
             training=True,
             dataset_name=DatasetName.IMAGENET,
             use_async=async_mode,
         )
Exemplo n.º 8
0
 def test_deepfm_functional_train(self):
     """Train the functional-API DeepFM model in sync, then async mode."""
     deepfm_params = (
         "input_dim=5383;embedding_dim=4;"
         "input_length=10;fc_unit=4"
     )
     deepfm_def = "deepfm_functional_api.deepfm_functional_api.custom_model"
     for async_mode in (False, True):
         distributed_train_and_evaluate(
             10,
             _model_zoo_path,
             deepfm_def,
             model_params=deepfm_params,
             training=True,
             dataset_name=DatasetName.FRAPPE,
             use_async=async_mode,
         )
Exemplo n.º 9
0
 def _test_evaluate(
     self,
     feature_shape,
     model_def,
     model_params="",
     dataset_name=DatasetName.IMAGE_DEFAULT,
 ):
     """Run an evaluation-only job against freshly created parameter servers.

     Args:
         feature_shape: Forwarded as the first positional argument of
             ``distributed_train_and_evaluate``.
         model_def: Model definition string, e.g.
             ``"module.module.custom_model"``.
         model_params: Model parameter string forwarded as ``model_params``.
         dataset_name: Dataset forwarded as ``dataset_name``.

     Returns:
         The value returned by ``distributed_train_and_evaluate``.
     """
     num_ps_pods = 2
     grads_to_wait = 1
     # ``False`` fills the same positional slot that ``_test_train`` passes
     # its ``use_async`` flag in.
     _, ps_channels, pservers = create_pserver(
         _model_zoo_path, model_def, grads_to_wait, False, num_ps_pods
     )
     try:
         model_version = distributed_train_and_evaluate(
             feature_shape,
             _model_zoo_path,
             model_def,
             model_params=model_params,
             training=False,
             dataset_name=dataset_name,
             ps_channels=ps_channels,
             pservers=pservers,
         )
     finally:
         # Stop the servers and close the client channels so neither
         # outlives this helper, even when evaluation raises.
         for pserver in pservers:
             pserver.server.stop(0)
         for channel in ps_channels:
             channel.close()
     return model_version
Exemplo n.º 10
0
    def test_cifar10_train(self):
        """Train both CIFAR-10 models, each in sync and then async mode.

        For every model definition, the model version returned by the async
        run must equal twice the version returned by the sync run.
        """
        cifar10_models = (
            "cifar10_functional_api.cifar10_functional_api.custom_model",
            "cifar10_subclass.cifar10_subclass.CustomModel",
        )
        runs = itertools.product(cifar10_models, (False, True))

        versions = [
            distributed_train_and_evaluate(
                [32, 32, 3],
                _model_zoo_path,
                definition,
                training=True,
                use_async=async_mode,
            )
            for definition, async_mode in runs
        ]
        # Results alternate (sync, async) per model; compare each pair.
        for sync_version, async_version in zip(versions[0::2], versions[1::2]):
            self.assertEqual(sync_version * 2, async_version)
Exemplo n.º 11
0
 def _test_train(
     self,
     feature_shape,
     model_def,
     model_params="",
     dataset_name=DatasetName.IMAGE_DEFAULT,
 ):
     """Train ``model_def`` once in sync mode and once in async mode.

     Each run gets its own parameter servers from ``create_pserver``. The
     servers are stopped and their channels closed whether or not the run
     succeeds.

     Returns:
         A two-element list: the value returned by the sync run followed by
         the value returned by the async run.
     """
     ps_pod_count = 2
     versions = []
     for async_mode in (False, True):
         # Sync runs use grads_to_wait=2; async runs use 1.
         if async_mode:
             grads_to_wait = 1
         else:
             grads_to_wait = 2
         _, channels, servers = create_pserver(
             _model_zoo_path,
             model_def,
             grads_to_wait,
             async_mode,
             ps_pod_count,
         )
         try:
             version = distributed_train_and_evaluate(
                 feature_shape,
                 _model_zoo_path,
                 model_def,
                 model_params=model_params,
                 training=True,
                 dataset_name=dataset_name,
                 use_async=async_mode,
                 ps_channels=channels,
                 pservers=servers,
             )
         finally:
             for server in servers:
                 server.server.stop(0)
             for channel in channels:
                 channel.close()
         versions.append(version)
     return versions
Exemplo n.º 12
0
    def test_mnist_train(self):
        """Train both MNIST models in sync and async mode (currently skipped).

        When re-enabled, the model version returned by the async run must
        equal twice the version returned by the sync run, per model.
        """
        # TODO(qijun): need to rewrite `distributed_train_and_evaluate`.
        # skipTest makes runners report this test as skipped rather than
        # silently passing; the body below is kept for re-enabling.
        self.skipTest(
            "TODO(qijun): distributed_train_and_evaluate needs a rewrite"
        )
        model_defs = [
            "mnist_functional_api.mnist_functional_api.custom_model",
            "mnist_subclass.mnist_subclass.CustomModel",
        ]
        use_asyncs = [False, True]
        configs = list(itertools.product(model_defs, use_asyncs))

        model_versions = []
        for config in configs:
            model_version = distributed_train_and_evaluate(
                [28, 28],
                _model_zoo_path,
                config[0],
                training=True,
                use_async=config[1],
            )
            model_versions.append(model_version)
        # async model version = sync model version * 2
        self.assertEqual(model_versions[0] * 2, model_versions[1])
        self.assertEqual(model_versions[2] * 2, model_versions[3])