Example #1

    def get_data(self):
        """
        Trajectories are a special case: the atomic file stores individual
        points rather than whole trajectories, so the points must first be
        cut into trajectories.
        """
        if self.data is None:
            if self.config['cache_dataset'] and os.path.exists(
                    self.encoder.cache_file_name):
                # load cache
                with open(self.encoder.cache_file_name, 'r') as f:
                    self.data = json.load(f)
                self.pad_item = self.data['pad_item']
            else:
                cut_data = self.cutter_filter()
                encoded_data = self.encode_traj(cut_data)
                self.data = encoded_data
                self.pad_item = self.encoder.pad_item
                if self.config['cache_dataset']:
                    if not os.path.exists(self.cache_file_folder):
                        os.makedirs(self.cache_file_folder)
                    with open(self.encoder.cache_file_name, 'w') as f:
                        json.dump(encoded_data, f)
        # Split either by user or by number of trajectories.
        # TODO: make this a config option; for now, split by number of trajectories.
        train_data, eval_data, test_data = self.divide_data()
        return generate_dataloader(train_data, eval_data, test_data,
                                   self.encoder.feature_dict,
                                   self.config['batch_size'],
                                   self.config['num_workers'], self.pad_item,
                                   self.encoder.feature_max_len)
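
The branch structure above is a standard cache-or-compute pattern: reuse the JSON cache written by a previous run, otherwise cut, encode, and persist the result. A minimal self-contained sketch of that pattern, using hypothetical file names and a stand-in for the cutting/encoding step, might look like this:

import json
import os

CACHE_FOLDER = './cache'                          # hypothetical location
CACHE_FILE = './cache/encoded_traj.json'          # hypothetical cache file

def encode_trajectories():
    # Stand-in for cutter_filter() + encode_traj(): cut the raw points into
    # trajectories and encode them into model-ready features.
    return {'pad_item': {'loc': 0}, 'encoded': [[1, 2, 3], [4, 5]]}

def load_or_build(cache_dataset=True):
    # Reuse the cached result if it exists, otherwise compute and persist it.
    if cache_dataset and os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'r') as f:
            return json.load(f)
    data = encode_trajectories()
    if cache_dataset:
        os.makedirs(CACHE_FOLDER, exist_ok=True)
        with open(CACHE_FILE, 'w') as f:
            json.dump(data, f)
    return data

data = load_or_build()
pad_item = data['pad_item']
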
Example #2

    def get_data(self):
        """
        Return the DataLoaders for the data: training, validation, and test sets.

        Returns:
            tuple: tuple contains:
                train_dataloader: Dataloader composed of Batch (class) \n
                eval_dataloader: Dataloader composed of Batch (class) \n
                test_dataloader: Dataloader composed of Batch (class)
        """
        # load the dataset
        x_train, y_train, x_val, y_val, x_test, y_test = [], [], [], [], [], []
        ext_x_train, ext_y_train, ext_x_test, ext_y_test, ext_x_val, ext_y_val = [], [], [], [], [], []
        if self.data is None:
            self.data = {}
            if self.cache_dataset and os.path.exists(self.cache_file_name):
                x_train, y_train, x_val, y_val, x_test, y_test,  \
                    ext_x_train, ext_y_train, ext_x_test, ext_y_test, ext_x_val, ext_y_val \
                    = self._load_cache_train_val_test()
            else:
                x_train, y_train, x_val, y_val, x_test, y_test, \
                    ext_x_train, ext_y_train, ext_x_test, ext_y_test, ext_x_val, ext_y_val \
                    = self._generate_train_val_test()
        # normalize the data
        self.feature_dim = x_train.shape[-1]
        self.ext_dim = ext_x_train.shape[-1]
        self.scaler = self._get_scalar(self.scaler_type,
                                       x_train[..., :self.output_dim],
                                       y_train[..., :self.output_dim])
        self.ext_scaler = self._get_scalar(self.ext_scaler_type,
                                           x_train[..., self.output_dim:],
                                           y_train[..., self.output_dim:])
        x_train = self.scaler.transform(x_train)
        y_train = self.scaler.transform(y_train)
        x_val = self.scaler.transform(x_val)
        y_val = self.scaler.transform(y_val)
        x_test = self.scaler.transform(x_test)
        y_test = self.scaler.transform(y_test)
        if self.normal_external:
            ext_x_train = self.ext_scaler.transform(ext_x_train)
            ext_y_train = self.ext_scaler.transform(ext_y_train)
            ext_x_val = self.ext_scaler.transform(ext_x_val)
            ext_y_val = self.ext_scaler.transform(ext_y_val)
            ext_x_test = self.ext_scaler.transform(ext_x_test)
            ext_y_test = self.ext_scaler.transform(ext_y_test)
        # Zip the training X and y together into a list; same for the validation and test sets.
        # x_train/y_train: (num_samples, input_length, ..., feature_dim)
        # train_data (list): train_data[i] is a tuple of (x_train[i], y_train[i], ext_x_train[i], ext_y_train[i])
        train_data = list(zip(x_train, y_train, ext_x_train, ext_y_train))
        eval_data = list(zip(x_val, y_val, ext_x_val, ext_y_val))
        test_data = list(zip(x_test, y_test, ext_x_test, ext_y_test))
        # convert to DataLoaders
        self.train_dataloader, self.eval_dataloader, self.test_dataloader = \
            generate_dataloader(train_data, eval_data, test_data, self.feature_name,
                                self.batch_size, self.num_workers, pad_with_last_sample=self.pad_with_last_sample)
        return self.train_dataloader, self.eval_dataloader, self.test_dataloader
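
The normalization step above relies on `_get_scalar` returning an object that exposes `transform` (and, presumably, an `inverse_transform` for un-scaling predictions; that part is an assumption). A minimal z-score scaler with that interface, written as a sketch rather than the library's actual implementation, could look like:

import numpy as np

class ZScoreScaler:
    # Hypothetical stand-in for the object returned by _get_scalar:
    # normalize with the training mean/std, and invert for evaluation.
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return data * self.std + self.mean

# toy data shaped (num_samples, input_length, feature_dim)
x_train = np.random.rand(8, 12, 2).astype(np.float32)
scaler = ZScoreScaler(mean=x_train.mean(), std=x_train.std())
x_norm = scaler.transform(x_train)
assert np.allclose(scaler.inverse_transform(x_norm), x_train, atol=1e-5)
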
Example #3

    def get_data(self):
        # Same pattern as the previous example, specialized for STDN-style
        # inputs (see _get_scalar_stdn): load or generate the train/val/test
        # splits, fit scalers on the training data, normalize every input
        # stream, then build the dataloaders.
        x_train, y_train, flatten_att_nbhd_inputs_train, flatten_att_flow_inputs_train, \
            att_lstm_inputs_train, nbhd_inputs_train, flow_inputs_train, lstm_inputs_train \
            = [], [], [], [], [], [], [], []
        x_val, y_val, flatten_att_nbhd_inputs_val, flatten_att_flow_inputs_val, \
            att_lstm_inputs_val, nbhd_inputs_val, flow_inputs_val, lstm_inputs_val \
            = [], [], [], [], [], [], [], []
        x_test, y_test, flatten_att_nbhd_inputs_test, flatten_att_flow_inputs_test, \
            att_lstm_inputs_test, nbhd_inputs_test, flow_inputs_test, lstm_inputs_test \
            = [], [], [], [], [], [], [], []
        if self.data is None:
            self.data = {}
            if self.cache_dataset and os.path.exists(self.cache_file_name):
                x_train, y_train, flatten_att_nbhd_inputs_train, flatten_att_flow_inputs_train, att_lstm_inputs_train, nbhd_inputs_train, flow_inputs_train, lstm_inputs_train, \
                x_val, y_val, flatten_att_nbhd_inputs_val, flatten_att_flow_inputs_val, att_lstm_inputs_val, nbhd_inputs_val, flow_inputs_val, lstm_inputs_val, \
                x_test, y_test, flatten_att_nbhd_inputs_test, flatten_att_flow_inputs_test, att_lstm_inputs_test, nbhd_inputs_test, flow_inputs_test, lstm_inputs_test = self._load_cache_train_val_test()
            else:
                x_train, y_train, flatten_att_nbhd_inputs_train, flatten_att_flow_inputs_train, att_lstm_inputs_train, nbhd_inputs_train, flow_inputs_train, lstm_inputs_train, \
                x_val, y_val, flatten_att_nbhd_inputs_val, flatten_att_flow_inputs_val, att_lstm_inputs_val, nbhd_inputs_val, flow_inputs_val, lstm_inputs_val, \
                x_test, y_test, flatten_att_nbhd_inputs_test, flatten_att_flow_inputs_test, att_lstm_inputs_test, nbhd_inputs_test, flow_inputs_test, lstm_inputs_test = self._generate_train_val_test()
        self.feature_dim = x_train.shape[-1]
        self.feature_vec_len = lstm_inputs_train.shape[-1]
        self.nbhd_type = nbhd_inputs_train.shape[-1]
        self.scaler, self.flow_scaler = self._get_scalar_stdn(
            x_train, y_train, flow_inputs_train)
        x_train = self.scaler.transform(x_train)
        y_train = self.scaler.transform(y_train)
        flatten_att_nbhd_inputs_train = self.scaler.transform(
            flatten_att_nbhd_inputs_train)
        att_lstm_inputs_train = self.scaler.transform(att_lstm_inputs_train)
        nbhd_inputs_train = self.scaler.transform(nbhd_inputs_train)
        lstm_inputs_train = self.scaler.transform(lstm_inputs_train)
        x_val = self.scaler.transform(x_val)
        y_val = self.scaler.transform(y_val)
        flatten_att_nbhd_inputs_val = self.scaler.transform(
            flatten_att_nbhd_inputs_val)
        att_lstm_inputs_val = self.scaler.transform(att_lstm_inputs_val)
        nbhd_inputs_val = self.scaler.transform(nbhd_inputs_val)
        lstm_inputs_val = self.scaler.transform(lstm_inputs_val)
        x_test = self.scaler.transform(x_test)
        y_test = self.scaler.transform(y_test)
        flatten_att_nbhd_inputs_test = self.scaler.transform(
            flatten_att_nbhd_inputs_test)
        att_lstm_inputs_test = self.scaler.transform(att_lstm_inputs_test)
        nbhd_inputs_test = self.scaler.transform(nbhd_inputs_test)
        lstm_inputs_test = self.scaler.transform(lstm_inputs_test)

        flatten_att_flow_inputs_train = self.flow_scaler.transform(
            flatten_att_flow_inputs_train)
        flow_inputs_train = self.flow_scaler.transform(flow_inputs_train)
        flatten_att_flow_inputs_val = self.flow_scaler.transform(
            flatten_att_flow_inputs_val)
        flow_inputs_val = self.flow_scaler.transform(flow_inputs_val)
        flatten_att_flow_inputs_test = self.flow_scaler.transform(
            flatten_att_flow_inputs_test)
        flow_inputs_test = self.flow_scaler.transform(flow_inputs_test)

        train_data = list(
            zip(x_train, y_train, flatten_att_nbhd_inputs_train,
                flatten_att_flow_inputs_train, att_lstm_inputs_train,
                nbhd_inputs_train, flow_inputs_train, lstm_inputs_train))
        eval_data = list(
            zip(x_val, y_val, flatten_att_nbhd_inputs_val,
                flatten_att_flow_inputs_val, att_lstm_inputs_val,
                nbhd_inputs_val, flow_inputs_val, lstm_inputs_val))
        test_data = list(
            zip(x_test, y_test, flatten_att_nbhd_inputs_test,
                flatten_att_flow_inputs_test, att_lstm_inputs_test,
                nbhd_inputs_test, flow_inputs_test, lstm_inputs_test))
        self.train_dataloader, self.eval_dataloader, self.test_dataloader = \
            generate_dataloader(train_data, eval_data, test_data, self.feature_name,
                                self.batch_size, self.num_workers, pad_with_last_sample=self.pad_with_last_sample)
        return self.train_dataloader, self.eval_dataloader, self.test_dataloader
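
All three examples end by handing the zipped per-sample tuples to `generate_dataloader`. As a rough illustration of what that step amounts to (the real helper also builds `Batch` objects and handles padding, which this sketch omits), a simplified PyTorch version might be:

import numpy as np
import torch
from torch.utils.data import DataLoader

def simple_generate_dataloader(data, batch_size=4, num_workers=0, shuffle=True):
    # data is a list of per-sample tuples, e.g. (x_i, y_i, ext_x_i, ext_y_i).
    # The collate function stacks each field across the batch into a tensor.
    def collate(batch):
        return tuple(torch.as_tensor(np.stack(field)) for field in zip(*batch))
    return DataLoader(data, batch_size=batch_size, num_workers=num_workers,
                      shuffle=shuffle, collate_fn=collate)

xs = np.random.rand(10, 12, 2).astype(np.float32)
ys = np.random.rand(10, 3, 2).astype(np.float32)
train_loader = simple_generate_dataloader(list(zip(xs, ys)), batch_size=4)
for x_batch, y_batch in train_loader:
    print(x_batch.shape, y_batch.shape)
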