def get_cur_dataset(self, dataset_idx):
    """Select a dataset by index and configure its feature separator.

    Builds a ``FeaSeperator`` for ``self.dataset_list[dataset_idx]`` from the
    settings in ``self.config_content``, generates its rule-number tree from
    ``self.n_rules`` and optionally overrides one entry of that tree. The
    configured separator is stored on ``self.fea_seperator``.

    Config keys read:
        feature_seperator: 'slice_window' (window_size, step, n_level),
            'stride_window' (stride_len, n_level),
            'random_pick' (window_size, n_repeat_select, n_level) or
            'no_seperate'.
        tree_rule_spesify: 'true' (or boolean True) to apply the override
            given by ``n_rule_pos`` (two indices) and ``n_rule_spesify``.

    :param dataset_idx: index into ``self.dataset_list``.
    :return: the selected dataset name.
    """
    dataset_name = self.dataset_list[dataset_idx]
    config_content = self.config_content
    fea_seperator = FeaSeperator(dataset_name)

    # Choose how the features are split into branches.
    seperator_type = config_content['feature_seperator']
    if seperator_type == 'slice_window':
        fea_seperator.set_seperator_by_slice_window(
            config_content['window_size'],
            config_content['step'],
            config_content['n_level'])
    elif seperator_type == 'stride_window':
        fea_seperator.set_seperator_by_stride_window(
            config_content['stride_len'],
            config_content['n_level'])
    elif seperator_type == 'random_pick':
        fea_seperator.set_seperator_by_random_pick(
            config_content['window_size'],
            config_content['n_repeat_select'],
            config_content['n_level'])
    elif seperator_type == 'no_seperate':
        fea_seperator.set_seperator_by_no_seperate()
    # NOTE(review): any other value silently keeps FeaSeperator's default
    # separator; consider rejecting unknown types explicitly.

    # Build the tree of rule numbers matching the chosen separator.
    fea_seperator.generate_n_rule_tree(self.n_rules)

    # Optional manual override of one rule count. If the config comes from
    # JSON/YAML, a bare `true` arrives as a Python bool, which a plain
    # `== 'true'` comparison would silently ignore; accept both forms.
    if str(config_content['tree_rule_spesify']).strip().lower() == 'true':
        n_rule_pos = config_content['n_rule_pos']
        fea_seperator.set_n_rule_tree(
            n_rule_pos[0], n_rule_pos[1], config_content['n_rule_spesify'])

    self.fea_seperator = fea_seperator
    return dataset_name
def forward(self, **kwargs):
    """Fit per-branch neurons, then solve the top-level weights by AO.

    Each first-level branch of the feature separator is fitted with a fresh
    clone of the neuron seed. The per-branch ``h`` features are combined via
    two weight sets, ``w_x`` (shape ``(n_branch, n_hidden_output, n_h)``) and
    ``w_y`` (shape ``(n_branch * n_hidden_output, 1)``), which are updated
    alternately with closed-form ridge solves.

    Keyword arguments read here:
        train_data, test_data: datasets exposing ``X``, ``Y`` (and
            ``name`` / ``get_subset_fea`` for ``train_data``).
        seperator: optional ``FeaSeperator``; when absent a default one is
            built from ``train_data.name``.
        n_rules: passed to ``generate_n_rule_tree`` and used in the
            parameter-file name.
        n_hidden_output: hidden outputs per branch.
        para_mu, para_mu1: ridge penalties for the w_x and w_y solves.
        dataset_name: sub-directory for the saved parameter file.
    All kwargs are also forwarded to each branch neuron's ``forward`` with
    ``data`` and ``n_rules`` replaced by that branch's values.

    Side effects: calls ``self.set_neuron_tree``, sets ``self.__w_x`` /
    ``self.__w_y`` (best test epoch if one qualified, else the last one),
    and saves the final-epoch weights to
    ``./para_file/<dataset_name>/ao_<n_rules>_<n_hidden_output>.pt``.
    """
    import os

    train_data = kwargs['train_data']
    test_data = kwargs['test_data']
    if 'seperator' in kwargs:
        seperator = kwargs['seperator']
    else:
        # Keep the FeaSeperator object itself: get_seperator() and
        # generate_n_rule_tree() are called on it below, and predict()
        # needs it too.
        seperator = FeaSeperator(train_data.name)
    para_mu = kwargs['para_mu']
    para_mu1 = kwargs['para_mu1']
    n_rules = kwargs['n_rules']

    fea_seperator = seperator.get_seperator()
    # NOTE(review): this regenerates the rule tree, which may discard an
    # override previously applied via set_n_rule_tree — confirm intended.
    seperator.generate_n_rule_tree(n_rules)
    n_rules_tree = seperator.get_n_rule_tree()
    neuron_seed = self.get_neuron_seed()

    # Only the first level of the separator is used here.
    sub_seperator = fea_seperator[0]
    sub_dataset_list = train_data.get_subset_fea(sub_seperator)
    n_branch = len(sub_seperator)

    # The h width is derived from branch 0 (rules x (features + 1));
    # NOTE(review): every branch must produce the same width or the
    # reshape below fails.
    first_subset = sub_dataset_list[0]
    n_smpl = first_subset.X.shape[0]
    n_h = int(n_rules_tree[0][0]) * (first_subset.X.shape[1] + 1)

    # Fit one neuron per branch and collect its h features.
    rules_sub = []
    h_list = []
    for i in range(n_branch):
        neuron_seed.clear()
        neuron_c = neuron_seed.clone()
        sub_dataset_i = sub_dataset_list[i]
        # Per-branch copy, so the caller's n_rules is not clobbered.
        branch_kwargs = dict(kwargs, data=sub_dataset_i,
                             n_rules=int(n_rules_tree[0][i]))
        neuron_c.forward(**branch_kwargs)
        rules_sub.append(neuron_c)
        h_tmp, _ = neuron_c.get_h_computer().comute_h(
            sub_dataset_i.X, neuron_c.get_rules())
        # After permute the layout is N * n_rules * (d + 1); flatten to
        # (n_smpl, n_h).
        h_list.append(
            h_tmp.permute((1, 0, 2)).reshape(n_smpl, n_h).double())
    self.set_neuron_tree([rules_sub])
    h_all = torch.stack(h_list, 0)  # (n_branch, n_smpl, n_h)

    n_midle_output = kwargs['n_hidden_output']
    # w_x is overwritten in the first AO step; it is still drawn so the RNG
    # stream (and hence the w_y initialisation) stays unchanged.
    w_x = torch.rand(n_branch, n_midle_output, n_h).double()
    w_y = torch.rand(n_branch * n_midle_output, 1).double()

    # Design matrix for the w_x solve: each branch's h repeated once per
    # hidden output -> (n_smpl, n_branch * n_midle_output * n_h).
    h_brunch_all = torch.cat(
        [h_all[i].repeat(1, n_midle_output) for i in range(n_branch)], 1)

    data_save_dir = f"./para_file/{kwargs['dataset_name']}"
    os.makedirs(data_save_dir, exist_ok=True)
    para_file = f"{data_save_dir}/ao_{n_rules}_{n_midle_output}.pt"

    def _ridge(design, mu):
        # Closed-form ridge solution (A^T A + mu I)^-1 A^T Y.
        eye = torch.eye(design.shape[1]).double()
        return torch.inverse(design.t().mm(design) + mu * eye).mm(
            design.t().mm(train_data.Y))

    n_epoch = 20
    g_w_x_best = None
    g_w_y_best = None
    loss_test_best = 100
    for _ in range(1, n_epoch):  # n_epoch - 1 AO iterations
        # Fix w_y, update w_x.
        w_y_brunch_all = w_y.repeat(1, n_h).reshape(1, -1)
        w_x_h_brunch = torch.mul(h_brunch_all, w_y_brunch_all)
        w_x = _ridge(w_x_h_brunch, para_mu).reshape(
            n_branch, n_midle_output, n_h)

        # Fix w_x, update w_y; each branch contributes an
        # (n_smpl, n_midle_output) block.
        w_y_h_cal = torch.cat(
            [h_all[i].mm(w_x[i].t()) for i in range(n_branch)], 1)
        w_y = _ridge(w_y_h_cal, para_mu1)

        loss_train = mean_squared_error(train_data.Y, w_y_h_cal.mm(w_y))
        self.__w_x = w_x
        self.__w_y = w_y

        # Score the current weights on the test data and remember the best
        # epoch whose test loss is not below its training loss.
        test_y = self.predict(test_data, seperator)
        loss_test = mean_squared_error(test_y, test_data.Y)
        if loss_train <= loss_test < loss_test_best:
            g_w_x_best = w_x
            g_w_y_best = w_y
            loss_test_best = loss_test

    if g_w_x_best is not None:
        self.__w_x = g_w_x_best
        self.__w_y = g_w_y_best
    else:
        self.__w_x = w_x
        self.__w_y = w_y

    # Saves the final-epoch weights (not necessarily the selected ones).
    torch.save({"w_x": w_x, "w_y": w_y}, para_file)