예제 #1
0
 def test_non_ints(self):
     """IntInterval rejects non-int endpoints with a descriptive TypeError."""
     # Each case: (constructor args, regex expected in the TypeError message).
     cases = (
         ((float("-inf"), 0), "min_value must be an int: -inf"),
         ((7, "eleven"), "max_value must be an int: 'eleven'"),
     )
     for (low, high), pattern in cases:
         with six.assertRaisesRegex(self, TypeError, pattern):
             hp.IntInterval(low, high)
예제 #2
0
def _create_hparams_config(searchspace):
    """Build a list of ``hp.HParam`` objects describing a search space.

    Args:
        searchspace: object exposing ``names()``, a mapping from hyperparameter
            name to a type string, and ``get(name)``, which returns that
            hyperparameter's values (for DOUBLE/INTEGER, at least two items
            whose first two are used as the interval bounds).

    Returns:
        A list of ``hp.HParam``, in ``names()`` iteration order. Entries whose
        type is not DOUBLE, INTEGER, DISCRETE or CATEGORICAL are skipped.
    """
    result = []

    for name, param_type in searchspace.names().items():
        if param_type == "DOUBLE":
            bounds = searchspace.get(name)
            domain = hp.RealInterval(float(bounds[0]), float(bounds[1]))
        elif param_type == "INTEGER":
            bounds = searchspace.get(name)
            domain = hp.IntInterval(bounds[0], bounds[1])
        elif param_type in ("DISCRETE", "CATEGORICAL"):
            # Both kinds map to an explicit set of allowed values.
            domain = hp.Discrete(searchspace.get(name))
        else:
            # Unrecognized type: contribute nothing (and never call get()).
            continue
        result.append(hp.HParam(name, domain))

    return result
예제 #3
0
 def test_sample_uniform_unseeded(self):
     """Without an explicit RNG, sampling goes through `random.randint`."""
     domain = hp.IntInterval(2, 7)
     expected = object()
     # `randint` draws from a closed interval (both ends inclusive), which is
     # what we want here, as opposed to the half-open `randrange`.
     patcher = mock.patch.object(random, "randint", return_value=expected)
     with patcher as fake_randint:
         actual = domain.sample_uniform()
     self.assertIs(actual, expected)
     fake_randint.assert_called_once_with(2, 7)
예제 #4
0
 def test_sample_uniform(self):
     """With an explicit RNG, sampling calls that RNG's `randint(min, max)`."""
     expected = object()
     # `randint` samples a closed interval (both ends inclusive), which is
     # what we want, as opposed to the half-open `randrange`.
     rng = mock.Mock(**{"randint.return_value": expected})
     actual = hp.IntInterval(2, 7).sample_uniform(rng)
     self.assertIs(actual, expected)
     rng.randint.assert_called_once_with(2, 7)
예제 #5
0
 def _initialize_model(self, writer):
   """Set hparams, build and compile a small Keras model, and wrap a callback.

   Args:
     writer: passed, together with `self.hparams`, to `keras.Callback`.
   """
   dense_neurons = hp.HParam("dense_neurons", hp.IntInterval(4, 16))
   # Keys intentionally mix a plain string and an hp.HParam object.
   self.hparams = {
       "optimizer": "adam",
       dense_neurons: 8,
   }
   hidden_layer = tf.keras.layers.Dense(
       self.hparams[dense_neurons], input_shape=(1,))
   output_layer = tf.keras.layers.Dense(1, activation="sigmoid")
   self.model = tf.keras.models.Sequential([hidden_layer, output_layer])
   self.model.compile(loss="mse", optimizer=self.hparams["optimizer"])
   self.callback = keras.Callback(writer, self.hparams)
예제 #6
0
    def setUp(self):
        """Create hparam fixtures and the expected SessionStartInfo proto.

        Sets ``self.hparams`` (values keyed mostly by ``hp.HParam`` objects,
        plus one plain-string key), ``self.normalized_hparams`` (the same
        values keyed by hparam name), a start time, a trial id, and
        ``self.expected_session_start_pb``.
        """
        self.logdir = os.path.join(self.get_temp_dir(), "logs")
        # Mixed key types: most keys are hp.HParam objects (with and without
        # a domain), but "dropout" is a plain string.
        self.hparams = {
            hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)):
            0.02,
            hp.HParam("dense_layers", hp.IntInterval(2, 7)):
            5,
            hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])):
            "adam",
            # HParam declared with no domain.
            hp.HParam("who_knows_what"):
            "???",
            # HParam carrying display metadata (display_name / description).
            hp.HParam(
                "magic",
                hp.Discrete([False, True]),
                display_name="~*~ Magic ~*~",
                description="descriptive",
            ):
            True,
            "dropout":
            0.3,
        }
        # Same values as self.hparams, keyed by hparam name only; presumably
        # the form the code under test should reduce self.hparams to.
        self.normalized_hparams = {
            "learning_rate": 0.02,
            "dense_layers": 5,
            "optimizer": "adam",
            "who_knows_what": "???",
            "magic": True,
            "dropout": 0.3,
        }
        self.start_time_secs = 123.45
        self.trial_id = "psl27"

        # Expected proto: one `hparams` entry per value above. Numbers
        # (including the int 5) use number_value, strings use string_value,
        # and booleans use bool_value.
        self.expected_session_start_pb = plugin_data_pb2.SessionStartInfo()
        text_format.Merge(
            """
            hparams { key: "learning_rate" value { number_value: 0.02 } }
            hparams { key: "dense_layers" value { number_value: 5 } }
            hparams { key: "optimizer" value { string_value: "adam" } }
            hparams { key: "who_knows_what" value { string_value: "???" } }
            hparams { key: "magic" value { bool_value: true } }
            hparams { key: "dropout" value { number_value: 0.3 } }
            """,
            self.expected_session_start_pb,
        )
        # group_name and start_time_secs come from the fixture attributes
        # rather than from the text proto.
        self.expected_session_start_pb.group_name = self.trial_id
        self.expected_session_start_pb.start_time_secs = self.start_time_secs
예제 #7
0
 def test_backward_endpoints(self):
     """A min_value larger than max_value is rejected with a ValueError."""
     low, high = 123, 45
     with six.assertRaisesRegex(self, ValueError, "123 > 45"):
         hp.IntInterval(low, high)
예제 #8
0
 def test_singleton_domain(self):
     """An interval whose two endpoints coincide is accepted as-is."""
     domain = hp.IntInterval(61, 61)
     expectations = (
         ("min_value", 61),
         ("max_value", 61),
         ("dtype", int),
     )
     for attribute, expected in expectations:
         self.assertEqual(getattr(domain, attribute), expected)
예제 #9
0
 def test_simple(self):
     """IntInterval exposes its bounds and reports `int` as its dtype."""
     domain = hp.IntInterval(3, 7)
     expectations = (
         ("min_value", 3),
         ("max_value", 7),
         ("dtype", int),
     )
     for attribute, expected in expectations:
         self.assertEqual(getattr(domain, attribute), expected)
예제 #10
0
    def setUp(self):
        """Create hparam/metric definitions and the expected Experiment proto.

        Sets ``self.hparams`` (a list of ``hp.HParam``), ``self.metrics`` (a
        list of ``hp.Metric``), ``self.time_created_secs``, and
        ``self.expected_experiment_pb`` built from a text-format proto.
        """
        self.logdir = os.path.join(self.get_temp_dir(), "logs")

        # Covers a real interval, an int interval, a string discrete domain,
        # a domain-less hparam, and a bool discrete domain with display
        # metadata.
        self.hparams = [
            hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)),
            hp.HParam("dense_layers", hp.IntInterval(2, 7)),
            hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])),
            hp.HParam("who_knows_what"),
            hp.HParam(
                "magic",
                hp.Discrete([False, True]),
                display_name="~*~ Magic ~*~",
                description="descriptive",
            ),
        ]
        # Metrics range from tag-only to fully specified (group, display
        # name, description, and dataset type).
        self.metrics = [
            hp.Metric("samples_per_second"),
            hp.Metric(group="train",
                      tag="batch_loss",
                      display_name="loss (train)"),
            hp.Metric(
                group="validation",
                tag="epoch_accuracy",
                display_name="accuracy (val.)",
                description="Accuracy on the _validation_ dataset.",
                dataset_type=hp.Metric.VALIDATION,
            ),
        ]
        self.time_created_secs = 1555624767.0

        # Expected Experiment proto for the definitions above; its
        # time_created_secs mirrors self.time_created_secs. Note that the
        # IntInterval hparam "dense_layers" is expected with type
        # DATA_TYPE_FLOAT64, and "who_knows_what" (no domain) carries only
        # its name.
        self.expected_experiment_pb = api_pb2.Experiment()
        text_format.Merge(
            """
            time_created_secs: 1555624767.0
            hparam_infos {
              name: "learning_rate"
              type: DATA_TYPE_FLOAT64
              domain_interval {
                min_value: 0.01
                max_value: 0.1
              }
            }
            hparam_infos {
              name: "dense_layers"
              type: DATA_TYPE_FLOAT64
              domain_interval {
                min_value: 2
                max_value: 7
              }
            }
            hparam_infos {
              name: "optimizer"
              type: DATA_TYPE_STRING
              domain_discrete {
                values {
                  string_value: "adam"
                }
                values {
                  string_value: "sgd"
                }
              }
            }
            hparam_infos {
              name: "who_knows_what"
            }
            hparam_infos {
              name: "magic"
              type: DATA_TYPE_BOOL
              display_name: "~*~ Magic ~*~"
              description: "descriptive"
              domain_discrete {
                values {
                  bool_value: false
                }
                values {
                  bool_value: true
                }
              }
            }
            metric_infos {
              name {
                tag: "samples_per_second"
              }
            }
            metric_infos {
              name {
                group: "train"
                tag: "batch_loss"
              }
              display_name: "loss (train)"
            }
            metric_infos {
              name {
                group: "validation"
                tag: "epoch_accuracy"
              }
              display_name: "accuracy (val.)"
              description: "Accuracy on the _validation_ dataset."
              dataset_type: DATASET_VALIDATION
            }
            """,
            self.expected_experiment_pb,
        )