예제 #1
0
    def setUp(self):
        """Create an online and an offline (``online=False``) Comet callback
        keyed by mode, plus a single mock trial shared by the tests."""
        online_logger = CometLoggerCallback()
        offline_logger = CometLoggerCallback(online=False)
        self.loggers = {"online": online_logger, "offline": offline_logger}

        self.trial = MockTrial({"p1": 1}, "trial_1", 1, "artifact")
예제 #2
0
    def test_kwargs_passthrough(self, experiment):
        """Check that extra keyword arguments given to CometLoggerCallback are
        passed through to comet_ml.Experiment on log_trial_start.

        ``experiment`` is presumably a mock patched in for comet_ml.Experiment
        by a decorator outside this view — confirm against the test class.
        """
        passthrough = {"kwarg_1": "val_1"}
        callback = CometLoggerCallback(**passthrough)
        callback.log_trial_start(
            MockTrial({"parameter": 1}, "trial2", 1, "artifact"))

        # Default kwargs used to create the experiment: every entry of
        # ``_exclude_autolog`` set to False, with the user kwargs layered on top.
        expected = dict.fromkeys(callback._exclude_autolog, False)
        expected.update(passthrough)

        experiment.assert_called_with(**expected)
예제 #3
0
 def setUp(self):
     """Create a default CometLoggerCallback and three mock trials."""
     self.logger = CometLoggerCallback()
     # NOTE(review): trial_2 and trial_3 share config {"p1": 2}; confirm this
     # duplication is intentional.
     trial_specs = [
         ({"p1": 1}, "trial_1", 1),
         ({"p1": 2}, "trial_2", 2),
         ({"p1": 2}, "trial_3", 3),
     ]
     self.trials = [
         MockTrial(config, trial_name, number, "artifact")
         for config, trial_name, number in trial_specs
     ]
예제 #4
0
def tune_function(api_key=None, project_name=None):
    """Tune ``train_function`` with results logged to Comet ML.

    Args:
        api_key: Forwarded to ``CometLoggerCallback`` as ``api_key``.
        project_name: Forwarded to ``CometLoggerCallback`` as ``project_name``.

    Returns:
        ``analysis.best_config`` from a ``tune.run`` call configured with
        ``metric="loss"`` and ``mode="min"``.
    """
    comet_callback = CometLoggerCallback(
        api_key=api_key,
        project_name=project_name,
        tags=["comet_example"],
    )
    # "mean" is grid-searched over three values; "sd" is sampled uniformly.
    search_space = {
        "mean": tune.grid_search([1, 2, 3]),
        "sd": tune.uniform(0.2, 0.8),
    }
    analysis = tune.run(
        train_function,
        name="comet",
        metric="loss",
        mode="min",
        callbacks=[comet_callback],
        config=search_space,
    )
    return analysis.best_config
예제 #5
0
    def test_configure_experiment_defaults(self):
        """Test CometLoggerCallback._configure_experiment_defaults.

        Every option listed in ``CometLoggerCallback._exclude_autolog`` must be
        falsy in ``experiment_kwargs`` on the default logger. Constructing a
        logger with ``<option>=True`` must enable exactly that option while
        the other excluded options stay disabled.
        """
        exclude = CometLoggerCallback._exclude_autolog

        # Test that autologging features are properly disabled on the default
        # logger held in ``self.logger``.
        for option in exclude:
            # subTest reports each failing option separately instead of
            # aborting at the first one.
            with self.subTest(default_option=option):
                self.assertFalse(self.logger.experiment_kwargs.get(option))

        # Don't disable logging if user overwrites defaults by passing in args
        for include_option in exclude:
            # This unpacks to become e.g. CometLoggerCallback(log_env_cpu=True)
            logger = CometLoggerCallback(**{include_option: True})
            for option in exclude:
                with self.subTest(include_option=include_option, option=option):
                    if option == include_option:
                        self.assertTrue(logger.experiment_kwargs.get(option))
                    else:
                        self.assertFalse(logger.experiment_kwargs.get(option))
예제 #6
0
def train_model(train_dataset: ray.data.Dataset, comet_project: str) -> Result:
    """Fit a small XGBoost model on ``train_dataset`` with Comet ML logging.

    Args:
        train_dataset: Used as the ``"train"`` dataset; the label is read
            from its ``"target"`` column.
        comet_project: Comet ML project name passed to the logger callback.

    Returns:
        The ``Result`` returned by ``trainer.fit()``.
    """
    # This is the part needed to enable logging to Comet ML. It assumes
    # Comet ML can find a valid API key on its own (e.g. by setting the
    # ``COMET_API_KEY`` environment variable).
    comet_callback = CometLoggerCallback(
        project_name=comet_project,
        save_checkpoints=True,
    )
    trainer = XGBoostTrainer(
        scaling_config={"num_workers": 2},
        params={"tree_method": "auto"},
        label_column="target",
        datasets={"train": train_dataset},
        num_boost_round=10,
        run_config=RunConfig(callbacks=[comet_callback]),
    )
    return trainer.fit()
예제 #7
0
 def setUp(self):
     """Create a default CometLoggerCallback, two mock trials, and a sample
     result dict combining metadata-style keys, list-valued ``hist_stats/*``
     entries, scalar/None metrics, and ``training_iteration``."""
     self.logger = CometLoggerCallback()
     self.trials = [
         MockTrial({"p1": p1}, trial_name, 1, "artifact")
         for p1, trial_name in ((1, "trial_1"), (2, "trial_2"))
     ]

     # The groups are merged in order, so key order matches the flat literal.
     metadata = {
         "config": {"p1": 1},
         "node_ip": "0.0.0.0",
         "hostname": "hostname_val",
         "pid": "1234",
         "date": "2000-01-01",
         "experiment_id": "1234",
         "trial_id": 1,
         "experiment_tag": "tag1",
     }
     histograms = {
         "hist_stats/episode_reward": [1, 0, 1, -1, 0, 1],
         "hist_stats/episode_lengths": [1, 2, 3, 4, 5, 6],
     }
     metrics = {"metric1": 0.8, "metric2": 1, "metric3": None}
     self.result = {
         **metadata,
         **histograms,
         **metrics,
         "training_iteration": 0,
     }
예제 #8
0
 def setUp(self):
     self.logger = CometLoggerCallback()