Exemplo n.º 1
0
 def __init__(self, model_creator, optimizer_creator, loss_creator):
     """
     Store the creators used by this builder.

     :param model_creator: stored unchanged on ``self.model_creator``.
     :param optimizer_creator: passed through ``validate_pytorch_optim``; the returned
         value is what ends up on ``self.optimizer_creator``.
     :param loss_creator: passed through ``validate_pytorch_loss``; the returned
         value is what ends up on ``self.loss_creator``.
     """
     # Imported lazily inside the method, exactly as the surrounding code does.
     from zoo.orca.automl import pytorch_utils

     self.model_creator = model_creator
     # The optimizer is validated before the loss; keep that order so the same
     # argument is reported first when both are rejected.
     self.optimizer_creator = pytorch_utils.validate_pytorch_optim(optimizer_creator)
     self.loss_creator = pytorch_utils.validate_pytorch_loss(loss_creator)
Exemplo n.º 2
0
    def from_torch(
        *,
        model_creator,
        optimizer,
        loss,
        logs_dir="/tmp/auto_estimator_logs",
        resources_per_trial=None,
        name=None,
    ):
        """
        Build an AutoEstimator backed by a PyTorch model builder.

        All arguments are keyword-only.

        :param model_creator: Function that creates the PyTorch model.
        :param optimizer: Either a PyTorch optimizer creator function or the name of a
            PyTorch optimizer given as a string. When a name is given, the learning rate
            search space should be declared under the key "lr" or LR_NAME
            (from zoo.orca.automl.pytorch_utils import LR_NAME); if no learning rate
            search space is declared, every estimator uses the default learning rate
            of 1e-3.
        :param loss: A PyTorch loss instance, a PyTorch loss creator function, or the
            name of a PyTorch loss given as a string.
        :param logs_dir: Local directory where logs and results are written. Defaults to
            "/tmp/auto_estimator_logs".
        :param resources_per_trial: Dict describing the resources given to each trial,
            e.g. {"cpu": 2}.
        :param name: Name of the auto estimator.

        :return: an AutoEstimator object.
        """
        from zoo.orca.automl.pytorch_utils import (
            validate_pytorch_loss,
            validate_pytorch_optim,
        )
        from zoo.automl.model import PytorchModelBuilder

        # Loss is validated before the optimizer; the order is kept so that the
        # same argument is reported first when both are rejected.
        checked_loss = validate_pytorch_loss(loss)
        checked_optimizer = validate_pytorch_optim(optimizer)

        builder = PytorchModelBuilder(
            model_creator=model_creator,
            optimizer_creator=checked_optimizer,
            loss_creator=checked_loss,
        )
        estimator = AutoEstimator(
            model_builder=builder,
            logs_dir=logs_dir,
            resources_per_trial=resources_per_trial,
            name=name,
        )
        return estimator