示例#1
0
    def learn_with_optimizer(
        self, evaluator: Optional[Evaluator], trial: Optional[Trial]
    ) -> None:
        """Train for up to ``self.max_epoch`` epochs with validation-based early stopping.

        Every ``self.validate_epoch`` epochs the model is scored with
        ``evaluator.get_target_score(self)`` (higher is better). Each new best
        score triggers ``self.save_state()`` and records the epoch count in
        ``self.learnt_config["max_epoch"]``. Training stops early after
        ``self.score_degradation_max`` consecutive validations without
        improvement. At the end, the best saved state is restored with
        ``self.load_state()``, but only if one was saved.

        Args:
            evaluator: Validation scorer. If ``None``, no validation, early
                stopping, trial reporting or state restoration happens and all
                ``max_epoch`` epochs are run.
            trial: Optional trial object. When given, the negated score is
                reported after each validation and ``trial.should_prune()`` is
                checked.

        Raises:
            exceptions.TrialPruned: If ``trial.should_prune()`` returns True.
        """
        self.start_learning()
        best_score = -float("inf")
        n_score_degradation = 0
        # load_state() is only meaningful after at least one save_state().
        # No snapshot exists if no validation epoch ever ran
        # (max_epoch < validate_epoch) or no score ever beat -inf.
        state_saved = False
        pb = progress_bar(range(self.max_epoch))
        for epoch in pb:
            self.run_epoch()
            # Validate only on every `validate_epoch`-th epoch (1-based count).
            if (epoch + 1) % self.validate_epoch:
                continue

            if evaluator is None:
                continue

            target_score = evaluator.get_target_score(self)

            pb.comment = f"{evaluator.target_metric_name}={target_score}"
            relevant_score = target_score
            if relevant_score > best_score:
                best_score = relevant_score
                self.save_state()
                state_saved = True
                # Number of epochs that produced the best model so far.
                self.learnt_config["max_epoch"] = epoch + 1
                n_score_degradation = 0
            else:
                n_score_degradation += 1
                if n_score_degradation >= self.score_degradation_max:
                    # Early stop; this last non-improving score is not
                    # reported to the trial.
                    break
            if trial is not None:
                # Negate the score: this loop maximizes target_score, so lower
                # reported values correspond to better models.
                trial.report(-relevant_score, epoch)
                if trial.should_prune():
                    raise exceptions.TrialPruned()

        if state_saved:
            self.load_state()
示例#2
0
def _train(trial, dataset, model_args, train_args):
    """Consume ``samme.train`` step by step, reporting progress to ``trial``.

    ``samme.train(dataset, model_args, train_args)`` is iterated as a sequence
    of ``(output, step, completed, models)`` tuples. Each tuple not flagged
    ``completed`` has its ``output`` reported to ``trial`` for ``step``. The
    loop stops at the first ``completed`` tuple (which is not reported) or
    when the iteration is exhausted.

    Returns:
        ``(output, models)`` from the last tuple consumed.

    Raises:
        exceptions.TrialPruned: If ``trial.should_prune()`` is true after a
            report.
        RuntimeError: If ``samme.train`` yields nothing, so there is no result
            to return.
    """
    result = None
    for output, step, completed, models in samme.train(
            dataset, model_args, train_args):
        result = (output, models)
        if completed:
            break
        trial.report(output, step)
        if trial.should_prune():
            raise exceptions.TrialPruned()
    if result is None:
        # Without this check the return would hit unbound locals and raise a
        # confusing UnboundLocalError.
        raise RuntimeError("samme.train yielded no training steps")
    return result
示例#3
0
    def _cross_validate_with_pruning(
            self,
            trial,  # type: trial_module.Trial
            estimator,  # type: BaseEstimator
    ):
        # type: (...) -> Dict[str, OneDimArrayLikeType]
        """Cross-validate ``estimator`` incrementally, reporting to ``trial`` for pruning.

        One clone of ``estimator`` is kept per CV fold. On each of
        ``self.max_iter`` iterations, every clone is partially fit on its
        training fold and scored on its test fold via
        ``self._partial_fit_and_score``. The NaN-ignoring mean test score
        across folds is then reported to ``trial`` for that iteration.

        Returns:
            Dict of per-fold arrays:
            - ``"fit_time"`` and ``"score_time"``, accumulated over all
              iterations;
            - ``"test_score"``, from the last iteration;
            - ``"train_score"`` (last iteration), only if
              ``self.return_train_score``.

        Raises:
            exceptions.TrialPruned: If ``trial.should_prune()`` is true. The
                scores gathered so far are first passed to
                ``self._store_scores``.
        """
        if is_classifier(estimator):
            # Copy so that adding "classes" does not mutate self.fit_params.
            partial_fit_params = self.fit_params.copy()
            classes = np.unique(self.y)

            # Supply the full label set unless the caller already did;
            # presumably needed because a single training fold may not
            # contain every class.
            partial_fit_params.setdefault("classes", classes)

        else:
            partial_fit_params = self.fit_params

        n_splits = self.cv.get_n_splits(self.X, self.y, groups=self.groups)
        estimators = [clone(estimator) for _ in range(n_splits)]
        scores = {
            "fit_time": np.zeros(n_splits),
            "score_time": np.zeros(n_splits),
            "test_score": np.empty(n_splits),
        }

        if self.return_train_score:
            scores["train_score"] = np.empty(n_splits)

        # Materialize the folds once so that estimators[i] sees the same
        # train/test split on every iteration. Re-calling cv.split each
        # iteration repeats work and may re-draw the folds for a randomized
        # splitter (e.g. shuffling without a fixed random_state). Fold i could
        # then be partially fit on samples that later appear in its test set.
        splits = list(self.cv.split(self.X, self.y, groups=self.groups))

        for step in range(self.max_iter):
            for i, (train, test) in enumerate(splits):
                out = self._partial_fit_and_score(estimators[i], train, test,
                                                  partial_fit_params)

                # Layout of `out`: [train_score (only if requested),
                # test_score, fit_time, score_time].
                if self.return_train_score:
                    scores["train_score"][i] = out.pop(0)

                scores["test_score"][i] = out[0]
                scores["fit_time"][i] += out[1]
                scores["score_time"][i] += out[2]

            intermediate_value = np.nanmean(scores["test_score"])

            trial.report(intermediate_value, step=step)

            if trial.should_prune():
                self._store_scores(trial, scores)

                raise exceptions.TrialPruned(
                    "trial was pruned at iteration {}.".format(step))

        return scores
示例#4
0
    def learn_with_optimizer(self, evaluator: Optional[Evaluator],
                             trial: Optional[Trial]) -> None:
        """Train for up to ``self.max_epoch`` epochs with validation-based early stopping.

        Every ``self.validate_epoch`` epochs the model is scored with
        ``evaluator.get_target_score(self)`` (higher is better). Each new best
        score triggers ``self.save_state()`` and records the epoch count in
        ``self.learnt_config["max_epoch"]``. Training stops early after
        ``self.score_degradation_max`` consecutive validations without
        improvement. At the end, the best saved state is restored with
        ``self.load_state()``, but only if one was saved. Progress is shown
        with a tqdm bar.

        Args:
            evaluator: Validation scorer. If ``None``, no validation, early
                stopping, trial reporting or state restoration happens and all
                ``max_epoch`` epochs are run.
            trial: Optional trial object. When given, the negated score is
                reported after each validation and ``trial.should_prune()`` is
                checked.

        Raises:
            exceptions.TrialPruned: If ``trial.should_prune()`` returns True.
        """
        self.start_learning()
        best_score = -float("inf")
        n_score_degradation = 0
        # load_state() is only meaningful after at least one save_state().
        # No snapshot exists if no validation epoch ever ran
        # (max_epoch < validate_epoch) or no score ever beat -inf.
        state_saved = False

        with tqdm(total=self.max_epoch) as progress_bar:
            for epoch in range(self.max_epoch):
                self.run_epoch()
                progress_bar.update(1)
                # Validate only on every `validate_epoch`-th epoch (1-based).
                if (epoch + 1) % self.validate_epoch:
                    continue

                if evaluator is None:
                    continue

                target_score = evaluator.get_target_score(self)

                progress_bar.set_description(f"valid_score={target_score}")
                relevant_score = target_score
                if relevant_score > best_score:
                    best_score = relevant_score
                    self.save_state()
                    state_saved = True
                    # Number of epochs that produced the best model so far.
                    self.learnt_config["max_epoch"] = epoch + 1
                    n_score_degradation = 0
                else:
                    n_score_degradation += 1
                    if n_score_degradation >= self.score_degradation_max:
                        # Early stop; this last non-improving score is not
                        # reported to the trial.
                        break
                if trial is not None:
                    # Negate the score: this loop maximizes target_score, so
                    # lower reported values correspond to better models.
                    trial.report(-relevant_score, epoch)
                    if trial.should_prune():
                        raise exceptions.TrialPruned()

            if state_saved:
                self.load_state()
示例#5
0
    def _train_nn_with_trial(
        self,
        mlp: hk.Transformed,
        config: MLPTrainingConfig,
        trial: Optional[optuna.Trial] = None,
        *,
        n_epochs: int = 512,
        mb_size: int = 128,
        val_score_degradation_max: int = 10,
    ) -> Tuple[float, int]:
        """Train ``mlp`` on the profile -> embedding data with early stopping.

        The network is trained with Adam (``config.learning_rate``) on
        mini-batches from ``self.stream(self.profile_train,
        self.embedding_train, mb_size)``. After every epoch it is evaluated on
        ``self.profile_test`` / ``self.embedding_test``. Each loss is the sum
        of per-row mean squared errors divided by the number of rows, i.e. an
        average over rows, assuming ``self.stream`` visits every row once.

        Args:
            mlp: Transformed network; called as
                ``mlp.apply(params, rng, X, training)``.
            config: Training configuration; only ``learning_rate`` is read.
            trial: If given, the validation loss is reported after every epoch
                and ``trial.should_prune()`` is checked.
            n_epochs: Maximum number of training epochs.
            mb_size: Mini-batch size passed to ``self.stream``.
            val_score_degradation_max: Stop after this many consecutive epochs
                without a new best validation loss.

        Returns:
            ``(best_val_score, best_epoch)``: the lowest validation loss seen
            and the 1-based epoch that reached it (0 if no epoch improved).

        Raises:
            exceptions.TrialPruned: If ``trial.should_prune()`` is true.
        """
        rng_key = jax.random.PRNGKey(0)
        rng_key, sub_key = jax.random.split(rng_key)
        params = mlp.init(
            sub_key,
            jnp.zeros((1, self.profile_train.shape[1]), dtype=jnp.float32),
            False,
        )
        opt = optax.adam(config.learning_rate)
        opt_state = opt.init(params)

        # This `sub_key` is never used, but the split advances `rng_key`. It is
        # kept so the random stream, and hence the results, stay the same.
        rng_key, sub_key = jax.random.split(rng_key)

        @partial(jax.jit, static_argnums=(3, ))
        def predict(params: hk.Params, rng: PRNGKey, X: jnp.ndarray,
                    training: bool) -> jnp.ndarray:
            return mlp.apply(params, rng, X, training)

        @partial(jax.jit, static_argnums=(4, ))
        def loss_fn(
            params: hk.Params,
            rng: PRNGKey,
            X: jnp.ndarray,
            Y: jnp.ndarray,
            training: bool,
        ) -> jnp.ndarray:
            # Per-row mean squared error, summed over the mini-batch.
            prediction = predict(params, rng, X, training)
            return ((Y - prediction)**2).mean(axis=1).sum()

        @jax.jit
        def update(
            params: hk.Params,
            rng: PRNGKey,
            opt_state: optax.OptState,
            X: jnp.ndarray,
            Y: jnp.ndarray,
        ) -> Tuple[jnp.ndarray, hk.Params, optax.OptState]:
            # value_and_grad yields the loss and its gradient from one traced
            # forward/backward pass instead of a separate forward call.
            loss_value, grad = jax.value_and_grad(loss_fn)(params, rng, X, Y,
                                                           True)
            updates, opt_state = opt.update(grad, opt_state)
            new_params = optax.apply_updates(params, updates)
            return loss_value, new_params, opt_state

        best_val_score = float("inf")
        score_degradation_count = 0
        best_epoch = 0
        for epoch in tqdm(range(n_epochs)):
            train_loss = 0
            for X_mb, y_mb, _ in self.stream(self.profile_train,
                                             self.embedding_train, mb_size):
                rng_key, sub_key = jax.random.split(rng_key)

                loss_value, params, opt_state = update(params, sub_key,
                                                       opt_state, X_mb, y_mb)
                train_loss += loss_value
            train_loss /= self.profile_train.shape[0]

            val_loss = 0
            for X_mb, y_mb, _ in self.stream(self.profile_test,
                                             self.embedding_test,
                                             mb_size,
                                             shuffle=False):
                # training=False; the rng key is presumably unused in this
                # mode (per the original author's note).
                val_loss += loss_fn(params, rng_key, X_mb, y_mb, False)
            val_loss /= self.profile_test.shape[0]
            if trial is not None:
                trial.report(val_loss, epoch)
                if trial.should_prune():
                    raise exceptions.TrialPruned()

            if val_loss < best_val_score:
                best_epoch = epoch + 1
                best_val_score = val_loss
                score_degradation_count = 0
            else:
                score_degradation_count += 1

            if score_degradation_count >= val_score_degradation_max:
                break

        return best_val_score, best_epoch