Example #1
    def save_result(self, save_path, filename=None):
        """
        将评估结果保存到 save_path 文件夹下的 filename 文件中

        Args:
            save_path: 保存路径
            filename: 保存文件名
        """
        self._logger.info('Evaluating in {} mode'.format(self.mode))
        self.evaluate()
        ensure_dir(save_path)
        if filename is None:  # default to a timestamped filename
            filename = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + '_' + \
                       self.config['model'] + '_' + self.config['dataset']
        self._logger.info('Evaluate result is ' + json.dumps(self.result))
        with open(os.path.join(save_path, '{}.json'.format(filename)), 'w') as f:
            json.dump(self.result, f)
        self._logger.info('Evaluate result is saved at ' +
                          os.path.join(save_path, '{}.json'.format(filename)))
        dataframe = {}
        for metric in self.metrics:
            dataframe[metric] = []
        for i in range(1, self.len_timeslots + 1):
            for metric in self.metrics:
                dataframe[metric].append(self.result[metric+'@'+str(i)])
        dataframe = pd.DataFrame(dataframe, index=range(1, self.len_timeslots + 1))
        dataframe.to_csv(os.path.join(save_path, '{}.csv'.format(filename)), index=False)
        self._logger.info('Evaluate result is saved at ' +
                          os.path.join(save_path, '{}.csv'.format(filename)))
        self._logger.info("\n" + str(dataframe))
        return dataframe
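
A minimal read-back sketch for the two artifacts written above (the directory and filename values are illustrative, not taken from the source):

import json
import os
import pandas as pd

save_path = './evaluate_cache'                    # hypothetical directory
filename = '2021_01_01_00_00_00_GRU_METR_LA'      # hypothetical timestamped name
with open(os.path.join(save_path, '{}.json'.format(filename))) as f:
    result = json.load(f)                         # dict keyed like 'MAE@1', 'MAE@2', ...
frame = pd.read_csv(os.path.join(save_path, '{}.csv'.format(filename)))  # one row per timeslot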
    def save_model(self, cache_name):
        """
        将当前的模型保存到文件

        Args:
            cache_name(str): 保存的文件名
        """
        ensure_dir(self.cache_dir)
        self._logger.info("Saved model at " + cache_name)
        torch.save((self.model.state_dict(), self.optimizer.state_dict()),
                   cache_name)
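
Since save_model stores a (model_state_dict, optimizer_state_dict) tuple, loading is symmetric. A minimal sketch of the counterpart (an assumption about its body; the library's own load_model is not shown in this listing):

    def load_model(self, cache_name):
        """Restore model and optimizer state saved by save_model (sketch)."""
        model_state, optimizer_state = torch.load(cache_name, map_location=self.device)
        self.model.load_state_dict(model_state)
        self.optimizer.load_state_dict(optimizer_state)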
    def _split_train_val_test(self, x, y, df=None):
        """
        划分训练集、测试集、验证集,并缓存数据集

        Args:
            x(np.ndarray): 输入数据 (num_samples, input_length, ..., feature_dim)
            y(np.ndarray): 输出数据 (num_samples, input_length, ..., feature_dim)

        Returns:
            tuple: tuple contains:
                x_train: (num_samples, input_length, ..., feature_dim) \n
                y_train: (num_samples, input_length, ..., feature_dim) \n
                x_val: (num_samples, input_length, ..., feature_dim) \n
                y_val: (num_samples, input_length, ..., feature_dim) \n
                x_test: (num_samples, input_length, ..., feature_dim) \n
                y_test: (num_samples, input_length, ..., feature_dim)
        """
        test_rate = 1 - self.train_rate - self.eval_rate
        num_samples = x.shape[0]
        num_test = round(num_samples * test_rate)
        num_train = round(num_samples * self.train_rate)
        num_val = num_samples - num_test - num_train

        # train
        x_train, y_train = x[:num_train], y[:num_train]
        # val
        x_val = x[num_train:num_train + num_val]
        y_val = y[num_train:num_train + num_val]
        # test
        x_test, y_test = x[-num_test:], y[-num_test:]
        self._logger.info("train\t" + "x: " + str(x_train.shape) + ", y: " +
                          str(y_train.shape))
        self._logger.info("eval\t" + "x: " + str(x_val.shape) + ", y: " +
                          str(y_val.shape))
        self._logger.info("test\t" + "x: " + str(x_test.shape) + ", y: " +
                          str(y_test.shape))

        self.adj_mx = self._generate_graph_with_data(data=df, length=num_test)
        if self.cache_dataset:
            ensure_dir(self.cache_file_folder)
            np.savez_compressed(self.cache_file_name,
                                x_train=x_train,
                                y_train=y_train,
                                x_test=x_test,
                                y_test=y_test,
                                x_val=x_val,
                                y_val=y_val,
                                adj_mx=self.adj_mx)
            self._logger.info('Saved at ' + self.cache_file_name)
        return x_train, y_train, x_val, y_val, x_test, y_test
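
The split is purely positional. A self-contained sketch of the same arithmetic with the default rates (train_rate=0.7, eval_rate=0.1, so test_rate=0.2):

import numpy as np

num_samples, train_rate, eval_rate = 100, 0.7, 0.1
x = np.arange(num_samples)
num_test = round(num_samples * (1 - train_rate - eval_rate))  # 20
num_train = round(num_samples * train_rate)                   # 70
num_val = num_samples - num_test - num_train                  # 10
x_train = x[:num_train]                   # samples 0..69
x_val = x[num_train:num_train + num_val]  # samples 70..79
x_test = x[-num_test:]                    # samples 80..99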
    def save_model_with_epoch(self, epoch):
        """
        保存某个epoch的模型

        Args:
            epoch(int): 轮数
        """
        ensure_dir(self.cache_dir)
        config = dict()
        config['model_state_dict'] = self.model.state_dict()
        config['optimizer_state_dict'] = self.optimizer.state_dict()
        config['epoch'] = epoch
        model_path = self.cache_dir + '/' + self.config[
            'model'] + '_' + self.config['dataset'] + '_epoch%d.tar' % epoch
        torch.save(config, model_path)
        self._logger.info("Saved model at {}".format(epoch))
        return model_path
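
The executor __init__ further below calls load_model_with_epoch; a minimal sketch of that counterpart, mirroring the dict saved here (an assumption, not the library's verbatim code):

    def load_model_with_epoch(self, epoch):
        """Load the model checkpoint saved at a given epoch (sketch)."""
        model_path = self.cache_dir + '/' + self.config['model'] + '_' + \
            self.config['dataset'] + '_epoch%d.tar' % epoch
        checkpoint = torch.load(model_path, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])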
    def _split_train_val_test(self, x, y, ext_x=None, ext_y=None):
        """
        Split the data into training, validation, and test sets, and cache them.

        Args:
            x(np.ndarray): input data (num_samples, T_c+T_p+T_t, ..., feature_dim)
            y(np.ndarray): output data (num_samples, 1, ..., feature_dim)
            ext_x(np.ndarray): external input data (num_samples, T_c+T_p+T_t, ext_dim)
            ext_y(np.ndarray): external output data (num_samples, ext_dim)

        Returns:
            tuple: tuple contains:
                x_train: (num_samples, T_c+T_p+T_t, ..., feature_dim) \n
                y_train: (num_samples, 1, ..., feature_dim) \n
                x_val: (num_samples, T_c+T_p+T_t, ..., feature_dim) \n
                y_val: (num_samples, 1, ..., feature_dim) \n
                x_test: (num_samples, T_c+T_p+T_t, ..., feature_dim) \n
                y_test: (num_samples, 1, ..., feature_dim) \n
                ext_x_train: (num_samples, T_c+T_p+T_t, ext_dim) \n
                ext_y_train: (num_samples, ext_dim) \n
                ext_x_val: (num_samples, T_c+T_p+T_t, ext_dim) \n
                ext_y_val: (num_samples, ext_dim) \n
                ext_x_test: (num_samples, T_c+T_p+T_t, ext_dim) \n
                ext_y_test: (num_samples, ext_dim)
        """
        test_rate = 1 - self.train_rate - self.eval_rate
        num_samples = x.shape[0]
        num_test = round(num_samples * test_rate)
        num_train = round(num_samples * self.train_rate)
        num_val = num_samples - num_test - num_train

        x_train, x_val, x_test = \
            x[:num_train], x[num_train:num_train + num_val], x[-num_test:]
        y_train, y_val, y_test = \
            y[:num_train], y[num_train:num_train + num_val], y[-num_test:]
        ext_x_train, ext_x_val, ext_x_test = \
            ext_x[:num_train], ext_x[num_train:num_train + num_val], ext_x[-num_test:]
        ext_y_train, ext_y_val, ext_y_test = \
            ext_y[:num_train], ext_y[num_train:num_train + num_val], ext_y[-num_test:]
        self._logger.info("train\t" + "x: " + str(x_train.shape) + ", y: " +
                          str(y_train.shape) + ", x_ext: " +
                          str(ext_x_train.shape) + ", y_ext: " +
                          str(ext_y_train.shape))
        self._logger.info("eval\t" + "x: " + str(x_val.shape) + ", y: " +
                          str(y_val.shape) + ", x_ext: " +
                          str(ext_x_val.shape) + ", y_ext: " +
                          str(ext_y_val.shape))
        self._logger.info("test\t" + "x: " + str(x_test.shape) + ", y: " +
                          str(y_test.shape) + ", x_ext: " +
                          str(ext_x_test.shape) + ", y_ext: " +
                          str(ext_y_test.shape))

        if self.cache_dataset:
            ensure_dir(self.cache_file_folder)
            np.savez_compressed(
                self.cache_file_name,
                x_train=x_train,
                y_train=y_train,
                x_test=x_test,
                y_test=y_test,
                x_val=x_val,
                y_val=y_val,
                ext_x_train=ext_x_train,
                ext_y_train=ext_y_train,
                ext_x_test=ext_x_test,
                ext_y_test=ext_y_test,
                ext_x_val=ext_x_val,
                ext_y_val=ext_y_val,
            )
            self._logger.info('Saved at ' + self.cache_file_name)
        return x_train, y_train, x_val, y_val, x_test, y_test, \
            ext_x_train, ext_y_train, ext_x_test, ext_y_test, ext_x_val, ext_y_val
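
A minimal sketch of reading such a cache back with np.load (the filename is whatever self.cache_file_name was at save time; the one below is hypothetical):

import numpy as np

cache = np.load('./libtraffic/cache/dataset_cache/traffic_state_sample.npz')
x_train, y_train = cache['x_train'], cache['y_train']
ext_x_val, ext_y_val = cache['ext_x_val'], cache['ext_y_val']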
Example #6
    def __init__(self, config):
        self.config = config
        self.dataset = self.config.get('dataset', '')
        self.batch_size = self.config.get('batch_size', 64)
        self.cache_dataset = self.config.get('cache_dataset', True)
        self.num_workers = self.config.get('num_workers', 0)
        self.pad_with_last_sample = self.config.get('pad_with_last_sample', True)
        self.train_rate = self.config.get('train_rate', 0.7)
        self.eval_rate = self.config.get('eval_rate', 0.1)
        self.scaler_type = self.config.get('scaler', 'none')
        self.ext_scaler_type = self.config.get('ext_scaler', 'none')
        self.load_external = self.config.get('load_external', False)
        self.normal_external = self.config.get('normal_external', False)
        self.add_time_in_day = self.config.get('add_time_in_day', False)
        self.add_day_in_week = self.config.get('add_day_in_week', False)
        self.input_window = self.config.get('input_window', 12)
        self.output_window = self.config.get('output_window', 12)
        self.parameters_str = \
            str(self.dataset) + '_' + str(self.input_window) + '_' + str(self.output_window) + '_' \
            + str(self.train_rate) + '_' + str(self.eval_rate) + '_' + str(self.scaler_type) + '_' \
            + str(self.batch_size) + '_' + str(self.load_external) + '_' + str(self.add_time_in_day) + '_' \
            + str(self.add_day_in_week) + '_' + str(self.pad_with_last_sample)
        self.cache_file_name = os.path.join(
            './libtraffic/cache/dataset_cache/',
            'traffic_state_{}.npz'.format(self.parameters_str))
        self.cache_file_folder = './libtraffic/cache/dataset_cache/'
        ensure_dir(self.cache_file_folder)
        self.data_path = './raw_data/' + self.dataset + '/'
        if not os.path.exists(self.data_path):
            raise ValueError("Dataset {} does not exist! Please ensure the path "
                             "'./raw_data/{}/' exists!".format(self.dataset, self.dataset))
        # load the dataset's config.json file
        self.weight_col = self.config.get('weight_col', '')
        self.data_col = self.config.get('data_col', '')
        self.ext_col = self.config.get('ext_col', '')
        self.geo_file = self.config.get('geo_file', self.dataset)
        self.rel_file = self.config.get('rel_file', self.dataset)
        self.data_files = self.config.get('data_files', self.dataset)
        self.ext_file = self.config.get('ext_file', self.dataset)
        self.output_dim = self.config.get('output_dim', 1)
        self.time_intervals = self.config.get('time_intervals', 300)  # seconds
        self.init_weight_inf_or_zero = self.config.get('init_weight_inf_or_zero', 'inf')
        self.set_weight_link_or_dist = self.config.get('set_weight_link_or_dist', 'dist')
        self.calculate_weight_adj = self.config.get('calculate_weight_adj', False)
        self.weight_adj_epsilon = self.config.get('weight_adj_epsilon', 0.1)
        # initialization
        self.data = None
        self.feature_name = {'X': 'float', 'y': 'float'}  # this class's only inputs are X and y
        self.adj_mx = None
        self.scaler = None
        self.ext_scaler = None
        self.feature_dim = 0
        self.ext_dim = 0
        self.num_nodes = 0
        self.num_batches = 0
        self._logger = getLogger()
        if os.path.exists(self.data_path + self.geo_file + '.geo'):
            self._load_geo()
        else:
            raise ValueError('.geo file not found!')
        if os.path.exists(self.data_path + self.rel_file + '.rel'):
            # the .rel file is optional; fall back to an all-zero adjacency matrix
            self._load_rel()
        else:
            self.adj_mx = np.zeros((len(self.geo_ids), len(self.geo_ids)),
                                   dtype=np.float32)
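
As a concrete illustration of the cache naming above, a config with dataset='METR_LA' (an illustrative name) and all defaults would yield:

# parameters_str == 'METR_LA_12_12_0.7_0.1_none_64_False_False_False_True'
# cache_file_name == './libtraffic/cache/dataset_cache/traffic_state_METR_LA_12_12_0.7_0.1_none_64_False_False_False_True.npz'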
    def _split_train_val_test_stdn(self, x, y, flatten_att_nbhd_inputs,
                                   flatten_att_flow_inputs, att_lstm_inputs,
                                   nbhd_inputs, flow_inputs, lstm_inputs):
        """
        划分训练集、测试集、验证集,并缓存数据集

        Args:
            x(np.ndarray): 输入数据 (num_samples, input_length, ..., feature_dim)
            y(np.ndarray): 输出数据 (num_samples, input_length, ..., feature_dim)

        Returns:
            tuple: tuple contains:
                x_train: (num_samples, input_length, ..., feature_dim) \n
                y_train: (num_samples, input_length, ..., feature_dim) \n
                x_val: (num_samples, input_length, ..., feature_dim) \n
                y_val: (num_samples, input_length, ..., feature_dim) \n
                x_test: (num_samples, input_length, ..., feature_dim) \n
                y_test: (num_samples, input_length, ..., feature_dim)
        """
        test_rate = 1 - self.train_rate - self.eval_rate
        num_samples = x.shape[0]
        num_test = round(num_samples * test_rate)
        num_train = round(num_samples * self.train_rate)
        num_val = num_samples - num_test - num_train

        # train
        x_train = x[:num_train]
        y_train = y[:num_train]
        flatten_att_nbhd_inputs_train = flatten_att_nbhd_inputs[:num_train]
        flatten_att_flow_inputs_train = flatten_att_flow_inputs[:num_train]
        att_lstm_inputs_train = att_lstm_inputs[:num_train]
        nbhd_inputs_train = nbhd_inputs[:num_train]
        flow_inputs_train = flow_inputs[:num_train]
        lstm_inputs_train = lstm_inputs[:num_train]
        # val
        x_val = x[num_train:num_train + num_val]
        y_val = y[num_train:num_train + num_val]
        flatten_att_nbhd_inputs_val = flatten_att_nbhd_inputs[
            num_train:num_train + num_val]
        flatten_att_flow_inputs_val = flatten_att_flow_inputs[
            num_train:num_train + num_val]
        att_lstm_inputs_val = att_lstm_inputs[num_train:num_train + num_val]
        nbhd_inputs_val = nbhd_inputs[num_train:num_train + num_val]
        flow_inputs_val = flow_inputs[num_train:num_train + num_val]
        lstm_inputs_val = lstm_inputs[num_train:num_train + num_val]
        # test
        x_test = x[-num_test:]
        y_test = y[-num_test:]
        flatten_att_nbhd_inputs_test = flatten_att_nbhd_inputs[-num_test:]
        flatten_att_flow_inputs_test = flatten_att_flow_inputs[-num_test:]
        att_lstm_inputs_test = att_lstm_inputs[-num_test:]
        nbhd_inputs_test = nbhd_inputs[-num_test:]
        flow_inputs_test = flow_inputs[-num_test:]
        lstm_inputs_test = lstm_inputs[-num_test:]
        self._logger.info("train\t" + "x: " + str(x_train.shape) + "y: " +
                          str(y_train.shape) + "flatten_att_nbhd_inputs: " +
                          str(flatten_att_nbhd_inputs_train.shape) +
                          "flatten_att_flow_inputs: " +
                          str(flatten_att_flow_inputs_train.shape) +
                          "att_lstm_inputs: " +
                          str(att_lstm_inputs_train.shape) + "nbhd_inputs: " +
                          str(nbhd_inputs_train.shape) + "flow_inputs: " +
                          str(flow_inputs_train.shape) + "lstm_inputs: " +
                          str(lstm_inputs_train.shape))
        self._logger.info("eval\t" + "x: " + str(x_val.shape) + "y: " +
                          str(y_val.shape) + "flatten_att_nbhd_inputs: " +
                          str(flatten_att_nbhd_inputs_val.shape) +
                          "flatten_att_flow_inputs: " +
                          str(flatten_att_flow_inputs_val.shape) +
                          "att_lstm_inputs: " +
                          str(att_lstm_inputs_val.shape) + "nbhd_inputs: " +
                          str(nbhd_inputs_val.shape) + "flow_inputs: " +
                          str(flow_inputs_val.shape) + "lstm_inputs: " +
                          str(lstm_inputs_val.shape))
        self._logger.info("test\t" + "x: " + str(x_test.shape) + "y: " +
                          str(y_test.shape) + "flatten_att_nbhd_inputs: " +
                          str(flatten_att_nbhd_inputs_test.shape) +
                          "flatten_att_flow_inputs: " +
                          str(flatten_att_flow_inputs_test.shape) +
                          "att_lstm_inputs: " +
                          str(att_lstm_inputs_test.shape) + "nbhd_inputs: " +
                          str(nbhd_inputs_test.shape) + "flow_inputs: " +
                          str(flow_inputs_test.shape) + "lstm_inputs: " +
                          str(lstm_inputs_test.shape))

        if self.cache_dataset:
            ensure_dir(self.cache_file_folder)
            np.savez_compressed(
                self.cache_file_name,
                x_train=x_train,
                y_train=y_train,
                flatten_att_nbhd_inputs_train=flatten_att_nbhd_inputs_train,
                flatten_att_flow_inputs_train=flatten_att_flow_inputs_train,
                att_lstm_inputs_train=att_lstm_inputs_train,
                nbhd_inputs_train=nbhd_inputs_train,
                flow_inputs_train=flow_inputs_train,
                lstm_inputs_train=lstm_inputs_train,
                x_test=x_test,
                y_test=y_test,
                flatten_att_nbhd_inputs_test=flatten_att_nbhd_inputs_test,
                flatten_att_flow_inputs_test=flatten_att_flow_inputs_test,
                att_lstm_inputs_test=att_lstm_inputs_test,
                nbhd_inputs_test=nbhd_inputs_test,
                flow_inputs_test=flow_inputs_test,
                lstm_inputs_test=lstm_inputs_test,
                x_val=x_val,
                y_val=y_val,
                flatten_att_nbhd_inputs_val=flatten_att_nbhd_inputs_val,
                flatten_att_flow_inputs_val=flatten_att_flow_inputs_val,
                att_lstm_inputs_val=att_lstm_inputs_val,
                nbhd_inputs_val=nbhd_inputs_val,
                flow_inputs_val=flow_inputs_val,
                lstm_inputs_val=lstm_inputs_val,
            )
            self._logger.info('Saved at ' + self.cache_file_name)
        return x_train, y_train, flatten_att_nbhd_inputs_train, flatten_att_flow_inputs_train, \
            att_lstm_inputs_train, nbhd_inputs_train, flow_inputs_train, lstm_inputs_train, \
            x_val, y_val, flatten_att_nbhd_inputs_val, flatten_att_flow_inputs_val, \
            att_lstm_inputs_val, nbhd_inputs_val, flow_inputs_val, lstm_inputs_val, \
            x_test, y_test, flatten_att_nbhd_inputs_test, flatten_att_flow_inputs_test, \
            att_lstm_inputs_test, nbhd_inputs_test, flow_inputs_test, lstm_inputs_test
    def __init__(self, config, model):
        self.evaluator = get_evaluator(config)
        self.config = config
        self.device = self.config.get('device', torch.device('cpu'))
        self.model = model.to(self.device)

        self.cache_dir = './libtraffic/cache/model_cache'
        self.evaluate_res_dir = './libtraffic/cache/evaluate_cache'
        self.summary_writer_dir = './libtraffic/log/runs'
        ensure_dir(self.cache_dir)
        ensure_dir(self.evaluate_res_dir)
        ensure_dir(self.summary_writer_dir)

        self._writer = SummaryWriter(self.summary_writer_dir)
        self._logger = getLogger()
        self._scaler = self.model.get_data_feature().get('scaler')
        for name, param in self.model.named_parameters():
            self._logger.info(
                str(name) + '\t' + str(param.shape) + '\t' +
                str(param.device) + '\t' + str(param.requires_grad))
        total_num = sum(
            [param.nelement() for param in self.model.parameters()])
        self._logger.info('Total parameter numbers: {}'.format(total_num))

        self.epochs = self.config.get('max_epoch', 100)
        self.train_loss = self.config.get('train_loss', 'none')
        self.learner = self.config.get('learner', 'adam')
        self.learning_rate = self.config.get('learning_rate', 0.01)
        self.weight_decay = self.config.get('weight_decay', 0)
        self.lr_beta1 = self.config.get('lr_beta1', 0.9)
        self.lr_beta2 = self.config.get('lr_beta2', 0.999)
        self.lr_betas = (self.lr_beta1, self.lr_beta2)
        self.lr_alpha = self.config.get('lr_alpha', 0.99)
        self.lr_epsilon = self.config.get('lr_epsilon', 1e-8)
        self.lr_momentum = self.config.get('lr_momentum', 0)
        self.lr_decay = self.config.get('lr_decay', False)
        self.lr_scheduler_type = self.config.get('lr_scheduler', 'multisteplr')
        self.lr_decay_ratio = self.config.get('lr_decay_ratio', 0.1)
        self.milestones = self.config.get('steps', [])
        self.step_size = self.config.get('step_size', 10)
        self.lr_lambda = self.config.get('lr_lambda', lambda x: x)
        self.lr_T_max = self.config.get('lr_T_max', 30)
        self.lr_eta_min = self.config.get('lr_eta_min', 0)
        self.lr_patience = self.config.get('lr_patience', 10)
        self.lr_threshold = self.config.get('lr_threshold', 1e-4)
        self.clip_grad_norm = self.config.get('clip_grad_norm', False)
        self.max_grad_norm = self.config.get('max_grad_norm', 1.)
        self.use_early_stop = self.config.get('use_early_stop', False)
        self.patience = self.config.get('patience', 50)
        self.log_every = self.config.get('log_every', 1)
        self.saved = self.config.get('saved_model', True)
        self.load_best_epoch = self.config.get('load_best_epoch', True)
        self.hyper_tune = self.config.get('hyper_tune', False)

        # self.output_dim = self.model.get_data_feature().get('output_dim', 1)
        self.output_dim = self.config.get('output_dim', 1)
        self.optimizer = self._build_optimizer()
        self.lr_scheduler = self._build_lr_scheduler()
        self._epoch_num = self.config.get('epoch', 0)
        if self._epoch_num > 0:
            self.load_model_with_epoch(self._epoch_num)
        self.loss_func = self._build_train_loss()
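
_build_optimizer itself is not shown in this listing; a minimal sketch of how the learner and lr_* settings read above could map onto torch.optim (an assumption about the helper, not its verbatim body):

    def _build_optimizer(self):
        """Map self.learner onto a torch.optim optimizer (sketch)."""
        if self.learner.lower() == 'adam':
            return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate,
                                    betas=self.lr_betas, eps=self.lr_epsilon,
                                    weight_decay=self.weight_decay)
        elif self.learner.lower() == 'sgd':
            return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate,
                                   momentum=self.lr_momentum,
                                   weight_decay=self.weight_decay)
        elif self.learner.lower() == 'rmsprop':
            return torch.optim.RMSprop(self.model.parameters(), lr=self.learning_rate,
                                       alpha=self.lr_alpha, eps=self.lr_epsilon,
                                       momentum=self.lr_momentum,
                                       weight_decay=self.weight_decay)
        # fall back to plain Adam for unknown learners
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)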
def hyper_parameter(task=None,
                    model_name=None,
                    dataset_name=None,
                    config_file=None,
                    space_file=None,
                    scheduler=None,
                    search_alg=None,
                    other_args=None,
                    num_samples=5,
                    max_concurrent=1,
                    cpu_per_trial=1,
                    gpu_per_trial=1):
    """ Use Ray tune to hyper parameter tune

    Args:
        task(str): task name
        model_name(str): model name
        dataset_name(str): dataset name
        config_file(str): config filename used to modify the pipeline's
            settings. the config file should be json.
        space_file(str): the file which specifies the parameter search space
        scheduler(str): the trial sheduler which will be used in ray.tune.run
        search_alg(str): the search algorithm
        other_args(dict): the rest parameter args, which will be pass to the Config
    """
    # load config
    experiment_config = ConfigParser(task,
                                     model_name,
                                     dataset_name,
                                     config_file=config_file,
                                     other_args=other_args)
    # logger
    logger = get_logger(experiment_config)
    # check space_file
    if space_file is None:
        logger.error('space_file must not be None for hyperparameter tuning.')
        exit(1)
    # parse space_file
    search_space = parse_search_space(space_file)
    # load dataset
    dataset = get_dataset(experiment_config)
    # get train valid test data
    train_data, valid_data, test_data = dataset.get_data()
    data_feature = dataset.get_data_feature()

    def train(config,
              checkpoint_dir=None,
              experiment_config=None,
              train_data=None,
              valid_data=None,
              data_feature=None):
        """trainable function which meets ray tune API

        Args:
            config (dict): A dict of hyperparameter.
        """
        # modify experiment_config
        for key in config:
            if key in experiment_config:
                experiment_config[key] = config[key]
        experiment_config['hyper_tune'] = True
        logger = get_logger(experiment_config)
        logger.info(
            'Begin pipeline, task={}, model_name={}, dataset_name={}'.format(
                str(task), str(model_name), str(dataset_name)))
        logger.info('running parameters: ' + str(config))
        # load model
        model = get_model(experiment_config, data_feature)
        # load executor
        executor = get_executor(experiment_config, model)
        # checkpoint by ray tune
        if checkpoint_dir:
            checkpoint = os.path.join(checkpoint_dir, 'checkpoint')
            executor.load_model(checkpoint)
        # train
        executor.train(train_data, valid_data)

    # init search algorithm and scheduler
    if search_alg == 'BasicSearch':
        algorithm = BasicVariantGenerator()
    elif search_alg == 'BayesOptSearch':
        algorithm = BayesOptSearch(metric='loss', mode='min')
        # add concurrency limit
        algorithm = ConcurrencyLimiter(algorithm,
                                       max_concurrent=max_concurrent)
    elif search_alg == 'HyperOpt':
        algorithm = HyperOptSearch(metric='loss', mode='min')
        # add concurrency limit
        algorithm = ConcurrencyLimiter(algorithm,
                                       max_concurrent=max_concurrent)
    else:
        raise ValueError('the search_alg is invalid.')
    if scheduler == 'FIFO':
        tune_scheduler = FIFOScheduler()
    elif scheduler == 'ASHA':
        tune_scheduler = ASHAScheduler()
    elif scheduler == 'MedianStoppingRule':
        tune_scheduler = MedianStoppingRule()
    else:
        raise ValueError('the scheduler is invalid.')
    # ray tune run
    ensure_dir('./libtraffic/cache/hyper_tune')
    result = tune.run(tune.with_parameters(train,
                                           experiment_config=experiment_config,
                                           train_data=train_data,
                                           valid_data=valid_data,
                                           data_feature=data_feature),
                      resources_per_trial={
                          'cpu': cpu_per_trial,
                          'gpu': gpu_per_trial
                      },
                      config=search_space,
                      metric='loss',
                      mode='min',
                      scheduler=tune_scheduler,
                      search_alg=algorithm,
                      local_dir='./libtraffic/cache/hyper_tune',
                      num_samples=num_samples)
    best_trial = result.get_best_trial("loss", "min", "last")
    logger.info("Best trial config: {}".format(best_trial.config))
    logger.info("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    # save best
    best_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
    model_state, optimizer_state = torch.load(best_path)
    model_cache_file = './libtraffic/cache/model_cache/{}_{}.m'.format(
        model_name, dataset_name)
    ensure_dir('./libtraffic/cache/model_cache')
    torch.save((model_state, optimizer_state), model_cache_file)
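
A minimal invocation sketch, assuming a prepared search-space file (all argument values here are illustrative):

hyper_parameter(task='traffic_state_pred',
                model_name='GRU',
                dataset_name='METR_LA',
                space_file='sample_space_file.json',
                scheduler='FIFO',
                search_alg='HyperOpt',
                num_samples=5)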