def fit(self,
            data,
            epochs=1,
            batch_size=32,
            validation_data=None,
            metric_threshold=None,
            n_sampling=1,
            search_alg=None,
            search_alg_params=None,
            scheduler=None,
            scheduler_params=None):
        """
        Fit using AutoEstimator.

        :param data: train data.
               For the "torch" backend, data can be a TSDataset or a function that takes a
               config dictionary as its parameter and returns a PyTorch DataLoader.
               For the "keras" backend, data can be a TSDataset.
        :param epochs: Max number of epochs to train in each trial. Defaults to 1.
               If you have also set metric_threshold, a trial will stop if either it has been
               optimized to the metric_threshold or it has been trained for {epochs} epochs.
        :param batch_size: Int or hp sampling function from an integer space. Training batch size.
               It defaults to 32.
        :param validation_data: validation data, which should be of the same type as data.
        :param metric_threshold: a trial will be terminated when the metric threshold is met.
        :param n_sampling: Number of times to sample from the search_space. Defaults to 1.
               If hp.grid_search is in search_space, the grid will be repeated n_sampling times.
               If this is -1, (virtually) infinite samples are generated
               until a stopping condition is met.
        :param search_alg: str, the search algorithm to use. All searchers provided by
               Ray Tune are supported (i.e. "variant_generator", "random", "ax",
               "dragonfly", "skopt", "hyperopt", "bayesopt", "bohb", "nevergrad",
               "optuna", "zoopt" and "sigopt").
        :param search_alg_params: extra parameters for the search algorithm, besides
               search_space, metric and the searcher mode.
        :param scheduler: str, the trial scheduler to use. All schedulers provided by
               Ray Tune are supported.
        :param scheduler_params: extra parameters for the scheduler.

        :return: a TSPipeline with the best model.
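
        Example:
            >>> # a minimal usage sketch; assumes auto_trainer is an instance of this
            >>> # class and tsdata_train / tsdata_valid are prepared TSDataset objects
            >>> ts_pipeline = auto_trainer.fit(data=tsdata_train,
            ...                                epochs=1,
            ...                                batch_size=hp.choice([32, 64]),
            ...                                validation_data=tsdata_valid,
            ...                                n_sampling=1)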
        """
        is_third_party_model = isinstance(self.model, AutoEstimator)

        # generate data creators from the TSDataset (the PyTorch backend requires validation data)
        if isinstance(data, TSDataset) and isinstance(validation_data,
                                                      TSDataset):
            train_d, val_d = self._prepare_data_creator(
                search_space=self.search_space
                if is_third_party_model else self.model.search_space,
                train_data=data,
                val_data=validation_data,
            )
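            # keep the scaler so the returned TSPipeline can unscale predictions later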
            self._scaler = data.scaler
            self._scaler_index = data.scaler_index
        else:
            train_d, val_d = data, validation_data

        if is_third_party_model:
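            # third-party models receive batch_size through the search space
            # rather than as a dedicated fit argument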
            self.search_space.update({"batch_size": batch_size})
            self.model.fit(
                data=train_d,
                epochs=epochs,
                validation_data=val_d,
                metric=self.metric,
                metric_threshold=metric_threshold,
                n_sampling=n_sampling,
                search_space=self.search_space,
                search_alg=search_alg,
                search_alg_params=search_alg_params,
                scheduler=scheduler,
                scheduler_params=scheduler_params,
            )

        else:
            self.model.fit(data=train_d,
                           epochs=epochs,
                           batch_size=batch_size,
                           validation_data=val_d,
                           metric_threshold=metric_threshold,
                           n_sampling=n_sampling,
                           search_alg=search_alg,
                           search_alg_params=search_alg_params,
                           scheduler=scheduler,
                           scheduler_params=scheduler_params)

        return TSPipeline(best_model=self._get_best_automl_model(),
                          best_config=self.get_best_config(),
                          scaler=self._scaler,
                          scaler_index=self._scaler_index)

    def test_fit_tcn_feature(self):
        input_feature_dim = 11  # This param will not be used
        output_feature_dim = 2  # 2 targets are generated in get_tsdataset

        from sklearn.preprocessing import StandardScaler
        scaler = StandardScaler()
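        # fit the scaler on the training set only, then reuse it for validation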
        tsdata_train = get_tsdataset().gen_dt_feature().scale(scaler, fit=True)
        tsdata_valid = get_tsdataset().gen_dt_feature().scale(scaler,
                                                              fit=False)

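        # mix an exhaustive grid over hidden_units with random sampling for the
        # remaining hyperparameters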
        search_space = {
            'hidden_units': hp.grid_search([32, 64]),
            'levels': hp.randint(4, 6),
            'kernel_size': hp.randint(3, 5),
            'dropout': hp.uniform(0.1, 0.2),
            'lr': hp.loguniform(0.001, 0.01)
        }
        auto_trainer = AutoTSTrainer(model='tcn',
                                     search_space=search_space,
                                     past_seq_len=hp.randint(4, 6),
                                     future_seq_len=1,
                                     input_feature_num=input_feature_dim,
                                     output_target_num=output_feature_dim,
                                     selected_features="auto",
                                     metric="mse",
                                     optimizer="Adam",
                                     loss=torch.nn.MSELoss(),
                                     logs_dir="/tmp/auto_trainer",
                                     cpus_per_trial=2,
                                     name="auto_trainer")
        ts_pipeline = auto_trainer.fit(data=tsdata_train,
                                       epochs=1,
                                       batch_size=hp.choice([32, 64]),
                                       validation_data=tsdata_valid,
                                       n_sampling=1)
        best_config = auto_trainer.get_best_config()
        best_model = auto_trainer.get_best_model()
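        # past_seq_len was searched over hp.randint(4, 6), so the best config
        # must fall inside that range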
        assert 4 <= best_config["past_seq_len"] <= 6

        assert isinstance(ts_pipeline, TSPipeline)

        # use the raw base model to predict and evaluate
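        # re-roll the validation set with the best lookback and selected features
        # so the raw model receives the same input layout it was trained on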
        tsdata_valid.roll(lookback=best_config["past_seq_len"],
                          horizon=0,
                          feature_col=best_config["selected_features"])
        x_valid, y_valid = tsdata_valid.to_numpy()
        y_pred_raw = best_model.predict(x_valid)
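        # the raw model predicts in scaled space; map predictions back to the
        # original scale before comparing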
        y_pred_raw = tsdata_valid.unscale_numpy(y_pred_raw)

        # use the TSPipeline to predict and evaluate
        eval_result = ts_pipeline.evaluate(tsdata_valid)
        y_pred = ts_pipeline.predict(tsdata_valid)

        # check if they are the same
        np.testing.assert_almost_equal(y_pred, y_pred_raw)

        # save and load
        ts_pipeline.save("/tmp/auto_trainer/autots_tmp_model_tcn")
        new_ts_pipeline = TSPipeline.load(
            "/tmp/auto_trainer/autots_tmp_model_tcn")

        # check that the loaded pipeline gives the same results as before
        eval_result_new = new_ts_pipeline.evaluate(tsdata_valid)
        y_pred_new = new_ts_pipeline.predict(tsdata_valid)
        np.testing.assert_almost_equal(eval_result[0], eval_result_new[0])
        np.testing.assert_almost_equal(y_pred, y_pred_new)

        # use the TSPipeline to incrementally train
        new_ts_pipeline.fit(tsdata_valid)