示例#1
0
    def build(self, hp: HyperParameters, *args: Any,
              **kwargs: Any) -> keras.Model:
        """Build and compile a model, retrying failed or oversized builds.

        Forwards ``hp`` plus any extra positional/keyword arguments to
        ``self.hypermodel.build``. At most ``self._max_fail_streak + 1``
        attempts are made. An attempt is retried when the wrapped build
        raises an ``Exception``, or when the returned model's size exceeds
        ``self.max_model_size`` (if that limit is set).

        Args:
            hp: Hyperparameters passed to the wrapped hypermodel's ``build``.
            *args: Extra positional arguments forwarded to the wrapped build.
            **kwargs: Extra keyword arguments forwarded to the wrapped build.

        Returns:
            The result of ``self._compile_model`` applied to the built model.

        Raises:
            RuntimeError: If the last allowed attempt fails to build, if the
                last allowed attempt is oversized, or if the wrapped build
                returns something that is not a ``keras.models.Model``.
        """
        for i in range(self._max_fail_streak + 1):
            # clean-up TF graph from previously stored (defunct) graph
            keras.backend.clear_session()
            gc.collect()

            # Build a model, allowing max_fail_streak failed attempts.
            try:
                with maybe_distribute(self.distribution_strategy):
                    # Only change compared to the original upstream version:
                    # extra *args/**kwargs are forwarded to the wrapped build.
                    # (upstream: model = self.hypermodel.build(hp))
                    model = self.hypermodel.build(hp, *args, **kwargs)
            except Exception as err:
                # Catch `Exception` rather than using a bare `except:` so that
                # KeyboardInterrupt / SystemExit abort the run instead of
                # being counted as a failed build attempt.
                if keras_tuner_config_module.DEBUG:
                    traceback.print_exc()

                print("Invalid model %s/%s" % (i, self._max_fail_streak))

                if i == self._max_fail_streak:
                    raise RuntimeError(
                        "Too many failed attempts to build model.") from err
                continue

            # Stop if `build()` does not return a valid model.
            if not isinstance(model, keras.models.Model):
                raise RuntimeError(
                    "Model-building function did not return "
                    "a valid Keras Model instance, found {}".format(model))

            # Check model size; oversized models count toward the fail streak.
            size = maybe_compute_model_size(model)
            if self.max_model_size and size > self.max_model_size:
                print("Oversized model: %s parameters -- skipping" % (size))
                if i == self._max_fail_streak:
                    raise RuntimeError(
                        "Too many consecutive oversized models.")
                continue
            break

        # Reaching here means the loop hit `break`, so `model` is bound.
        return self._compile_model(model)
示例#2
0
 def get_best_model(self):
     """Return the best model, loaded in full from ``self.best_model_path``.

     The saved model is loaded with ``tf.keras.models.load_model`` inside
     the context returned by
     ``hm_module.maybe_distribute(self.distribution_strategy)``.
     """
     distribute_scope = hm_module.maybe_distribute(self.distribution_strategy)
     with distribute_scope:
         best_model = tf.keras.models.load_model(self.best_model_path)
     return best_model
示例#3
0
 def get_best_model(self):
     """Rebuild the best model and restore its weights.

     The architecture comes from ``self._build_best_model()``. The weights
     are then read from ``self.best_model_path`` inside the context returned
     by ``hm_module.maybe_distribute(self.distribution_strategy)``.
     """
     # NOTE(review): the model is built outside the distribution scope and
     # only the weight loading runs inside it; confirm this is intended if
     # the model's variables are meant to be distributed.
     best_model = self._build_best_model()
     distribute_scope = hm_module.maybe_distribute(self.distribution_strategy)
     with distribute_scope:
         best_model.load_weights(self.best_model_path)
     return best_model