Пример #1
0
    def _cv(self, dataset: tk.data.Dataset,
            folds: tk.validation.FoldsType) -> None:
        """Cross-validate ``self.estimator`` over ``folds`` and log the result.

        For every (train, validation) split produced by ``dataset.iter(folds)``
        a fresh clone of ``self.estimator`` is fitted on the training split and
        appended to ``self.estimators_``. The clone is then evaluated on the
        validation split: with ``estimator.score`` when ``self.score_fn`` is
        None, otherwise with ``self.score_fn`` applied to the output of
        ``self._predict``. The per-fold evaluations are combined with
        ``tk.evaluations.mean`` using the validation-set sizes as weights, and
        the combined result is written to the log.

        Args:
            dataset: Data to split.
            folds: Fold definition passed to ``dataset.iter``.
        """
        self.estimators_ = []
        fold_evals = []
        fold_sizes = []

        progress = tk.utils.tqdm(enumerate(dataset.iter(folds)),
                                 total=len(folds),
                                 desc="cv")
        for fold, (train_set, val_set) in progress:
            # Forward sample weights only when the training split has them.
            if train_set.weights is None:
                fit_kwargs = {}
            else:
                fit_kwargs = {self.weights_arg_name: train_set.weights}

            fitted = sklearn.base.clone(self.estimator)
            fitted.fit(train_set.data, train_set.labels, **fit_kwargs)
            self.estimators_.append(fitted)

            # Same rule for the validation split (used by estimator.score).
            if val_set.weights is None:
                score_kwargs = {}
            else:
                score_kwargs = {self.weights_arg_name: val_set.weights}

            if self.score_fn is not None:
                # _predict is given the fold index; the estimator for this
                # fold has already been appended to self.estimators_ above.
                pred_val = self._predict(val_set, fold)
                result = self.score_fn(val_set.labels, pred_val)
            else:
                fold_score = fitted.score(val_set.data, val_set.labels,
                                          **score_kwargs)
                result = {"score": fold_score}

            fold_evals.append(result)
            fold_sizes.append(len(val_set))

        summary = tk.evaluations.mean(fold_evals, weights=fold_sizes)
        logger.info(f"cv: {tk.evaluations.to_str(summary)}")
Пример #2
0
 def _serial_cv(self, dataset: tk.data.Dataset, folds: tk.validation.FoldsType):
     """Train the folds one after another and log the combined evaluation.

     Each (train, validation) split from ``dataset.iter(folds)`` is passed to
     ``self.train`` together with its fold index. The evaluations returned by
     ``self.train`` are combined with ``tk.evaluations.mean``, weighted by the
     size of each validation split, and the result is logged.
     """
     # One (evaluations, validation size) pair per fold.
     per_fold = []
     for fold, (train_set, val_set) in enumerate(dataset.iter(folds)):
         logger.info(f"fold{fold}: train={len(train_set)} val={len(val_set)}")
         fold_result = self.train(train_set, val_set, fold=fold)
         per_fold.append((fold_result, len(val_set)))

     summary = tk.evaluations.mean(
         [result for result, _ in per_fold],
         weights=[size for _, size in per_fold],
     )
     logger.info(f"cv: {tk.evaluations.to_str(summary)}")
Пример #3
0
    def _parallel_cv(self, dataset: tk.data.Dataset,
                     folds: tk.validation.FoldsType):
        """Train the networks of all folds at once inside one combined model.

        A network is created for every fold via ``self.create_network``. The
        entries of ``self.train_models`` are then wired into a single Keras
        model whose inputs are the inputs *and* the targets of every fold
        model. The loss and the metrics of the combined model are the means,
        over the fold models, of the individual losses / metrics. The combined
        model is fitted on batches drawn from each fold's training split,
        evaluated on the validation splits, and the resulting values are
        logged.

        Args:
            dataset: Data to split.
            folds: Fold definition passed to ``dataset.iter``; its length is
                the number of networks created.
        """
        for fold in range(len(folds)):
            self.create_network(fold)
        assert self.train_models[0] is not None

        inputs = []
        targets = []
        outputs = []
        losses = []
        # One list of per-model metric tensors for every metric name except
        # "loss"; the lists are replaced by metric functions further below.
        metrics: typing.Dict[str, typing.Any] = {
            n: []
            for n in self.train_models[0].metrics_names if n != "loss"
        }
        for i, model in enumerate(self.train_models):
            assert model is not None
            input_shape = model.input_shape
            output_shape = model.output_shape
            # Normalize single-input / single-output models to a list of
            # shapes.
            if isinstance(input_shape, tuple):
                input_shape = [input_shape]
            if isinstance(output_shape, tuple):
                output_shape = [output_shape]

            # s[1:] drops the leading dimension of each shape. Targets are
            # fed through Input layers as well, because the loss is computed
            # inside the graph.
            model_inputs = [
                tf.keras.layers.Input(s[1:], name=f"model{i}_input{j}")
                for j, s in enumerate(input_shape)
            ]
            model_targets = [
                tf.keras.layers.Input(s[1:], name=f"model{i}_target{j}")
                for j, s in enumerate(output_shape)
            ]
            inputs.extend(model_inputs)
            targets.extend(model_targets)
            if len(model_targets) == 1:
                model_targets = model_targets[0]

            x = model(model_inputs)
            outputs.append(x)
            losses.extend(
                [loss_fn(model_targets, x) for loss_fn in model.loss_functions])
            assert len(metrics) == len(model.metrics)
            for k, m in zip(metrics, model.metrics):
                metrics[k].append(m(model_targets, x))

        def loss(y_true, y_pred):
            # The arguments supplied by Keras are ignored: the loss is the
            # mean of the loss tensors already built from the target inputs.
            del y_true, y_pred
            return tf.math.reduce_mean(losses, axis=0)

        for k, v in metrics.items():

            # v=v binds the current list at definition time (avoids the
            # late-binding closure pitfall).
            def metric_func(y_true, y_pred, v=v):
                del y_true, y_pred
                return tf.math.reduce_mean(v, axis=0)

            metric_func.__name__ = k
            metrics[k] = metric_func

        model = tf.keras.models.Model(inputs=inputs + targets, outputs=outputs)
        model.compile(self.train_models[0].optimizer, loss,
                      list(metrics.values()))
        tk.models.summary(model)

        def generator(datasets, data_loader):
            """Yield (inputs, None) batches keyed by the Input layer names."""
            iterators = [
                data_loader.iter(subset, shuffle=True).ds
                for subset in datasets
            ]
            while True:
                X_batch = {}
                for i, it in enumerate(iterators):
                    Xt, yt = next(it, (None, None))
                    assert Xt is not None
                    assert yt is not None

                    if isinstance(Xt, np.ndarray):
                        Xt = [Xt]
                    elif isinstance(Xt, dict):
                        Xt = Xt.values()  # TODO: ordering
                    for j, Xtj in enumerate(Xt):
                        X_batch[f"model{i}_input{j}"] = Xtj

                    if isinstance(yt, np.ndarray):
                        yt = [yt]
                    elif isinstance(yt, dict):
                        yt = yt.values()  # TODO: ordering
                    for j, ytj in enumerate(yt):
                        X_batch[f"model{i}_target{j}"] = ytj
                # No separate targets: they are part of the model inputs.
                yield X_batch, None

        train_sets, val_sets = zip(*list(dataset.iter(folds)))

        # -(-a // b) is ceiling division. Step counts are derived from the
        # first fold only.
        train_steps = -(-len(train_sets[0]) //
                        self.train_data_loader.batch_size)
        val_steps = -(-len(val_sets[0]) // self.val_data_loader.batch_size)

        model.fit(
            generator(train_sets, self.train_data_loader),
            steps_per_epoch=train_steps,
            validation_data=generator(val_sets, self.val_data_loader),
            validation_steps=val_steps,
            epochs=self.epochs,
            callbacks=self.callbacks,
            **(self.fit_params or {}),
        )

        # The step count must be passed by keyword: the second positional
        # parameter of Model.evaluate is ``y``, not ``steps``.
        # NOTE(review): the factor 3 is kept from the original; presumably it
        # averages over several shuffled passes - confirm.
        evals = model.evaluate(
            generator(val_sets, self.val_data_loader),
            steps=val_steps * 3,
        )
        # evaluate() returns a bare scalar when there is only a loss value.
        if not isinstance(evals, (list, tuple)):
            evals = [evals]
        scores = dict(zip(model.metrics_names, evals))
        for k, v in scores.items():
            tk.log.get(__name__).info(f"cv {k}: {v:,.3f}")