コード例 #1
0

# __trainable_end__

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--smoke-test",
                        action="store_true",
                        help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    ray.init()
    datasets.MNIST("~/data", train=True, download=True)

    # check if PytorchTrainble will save/restore correctly before execution
    validate_save_restore(PytorchTrainble)
    validate_save_restore(PytorchTrainble, use_object_store=True)

    # __pbt_begin__
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            # distribution for resampling
            "lr": lambda: np.random.uniform(0.0001, 1),
            # allow perturbations within this set of categorical values
            "momentum": [0.8, 0.9, 0.99],
        })
コード例 #2
0
ファイル: test_tune_restore.py プロジェクト: zlpmichelle/ray
 def testAsyncHyperbandExample(self):
     from ray.tune.examples.async_hyperband_example import MyTrainableClass
     validate_save_restore(MyTrainableClass)
     validate_save_restore(MyTrainableClass, use_object_store=True)
コード例 #3
0
ファイル: test_tune_restore.py プロジェクト: zlpmichelle/ray
 def testLogging(self):
     from ray.tune.examples.logging_example import MyTrainableClass
     validate_save_restore(MyTrainableClass)
     validate_save_restore(MyTrainableClass, use_object_store=True)
コード例 #4
0
ファイル: test_tune_restore.py プロジェクト: zlpmichelle/ray
 def testPyTorchMNIST(self):
     from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
     from torchvision import datasets
     datasets.MNIST("~/data", train=True, download=True)
     validate_save_restore(TrainMNIST)
     validate_save_restore(TrainMNIST, use_object_store=True)
コード例 #5
0
ファイル: test_tune_restore.py プロジェクト: zlpmichelle/ray
 def testPBTKeras(self):
     from ray.tune.examples.pbt_tune_cifar10_with_keras import Cifar10Model
     from tensorflow.python.keras.datasets import cifar10
     cifar10.load_data()
     validate_save_restore(Cifar10Model)
     validate_save_restore(Cifar10Model, use_object_store=True)
コード例 #6
0
ファイル: test_tune_restore.py プロジェクト: xuhuazhe/ray
 def testTensorFlowMNIST(self):
     from ray.tune.examples.tune_mnist_ray_hyperband import TrainMNIST
     validate_save_restore(TrainMNIST)
     validate_save_restore(TrainMNIST, use_object_store=True)