import os
import pickle
import random
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Batch, Data

# Repo-internal helpers (ArchDarts, DataSetDarts, Cell_101, Cell_201, OPS,
# MetricLogger, gen_batch_idx, convert_genotype_form, convert_arch_to_seq,
# nas2graph, nasbench2graph2, nasbench2graph_201, nasbench2graph_reverse,
# edit_distance) are assumed to be importable from the surrounding project.


def gen_darts_dataset(base_path, save_path=None):
    files = [os.path.join(base_path, f) for f in os.listdir(base_path)
             if not f.endswith('txt')]
    all_archs = []
    darts_dataset = DataSetDarts()
    for f in files:
        with open(f, 'rb') as fb:
            genotype_list = pickle.load(fb)
            keys_list = pickle.load(fb)
        for genotype, key in zip(genotype_list, keys_list):
            if len(all_archs) % 10000 == 0 and len(all_archs) != 0:
                print(f'{len(all_archs)} archs have been processed!')
            f_new = convert_genotype_form(genotype, OPS)
            arch = (f_new.normal, f_new.reduce)
            arch_darts = ArchDarts(arch)
            path_encoding_position_aware = arch_darts.get_path(
                path_type='path_enc_aware_vec', seq_len=612)
            path_encoding = arch_darts.get_path(
                path_type='path_enc_vec', seq_len=612)
            path_adj_encoding = arch_darts.get_path(
                path_type='adj_enc_vec', seq_len=612)
            matrix, ops = darts_dataset.assemble_graph_from_single_arch(arch)
            edge_indices, node_features = nasbench2graph2((matrix, ops))
            edge_reverse_indices, node_reverse_features = nasbench2graph2(
                (matrix, ops), reverse=True)
            all_archs.append({
                'matrix': matrix,
                'ops': ops,
                'pe_adj_enc_vec': path_adj_encoding,
                'pe_path_enc_vec': path_encoding,
                'pe_path_enc_aware_vec': path_encoding_position_aware,
                'hash_key': key,
                'genotype': genotype,
                'edge_idx': edge_indices,
                'node_f': node_features,
                'g_data': Data(edge_index=edge_indices.long(),
                               x=node_features.float()),
                'edge_idx_reverse': edge_reverse_indices,
                'node_f_reverse': node_reverse_features,
                'g_data_reverse': Data(edge_index=edge_reverse_indices.long(),
                                       x=node_reverse_features.float()),
            })
    if save_path:
        with open(save_path, 'wb') as fb:
            pickle.dump(all_archs, fb)
    return all_archs
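
# Hypothetical usage sketch (the directory and file names are assumptions
# for illustration, not part of the original pipeline): every non-.txt file
# under base_path must hold two consecutive pickles, a genotype list
# followed by its matching key list.
def _demo_gen_darts_dataset():
    archs = gen_darts_dataset('data/darts_archs',
                              save_path='data/darts_dataset.pkl')
    print(f'{len(archs)} architectures assembled, '
          f'e.g. key {archs[0]["hash_key"]}')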
def fit(self, edge_index, node_feature, edge_index_reverse,
        node_feature_reverse, val_accuracy):
    meters = MetricLogger(delimiter=" ")
    self.stage1.train()
    for epoch in range(self.epochs):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, 10)
        for i, batch_idx in enumerate(batch_idx_list):
            data_list = []
            data_list_reverse = []
            target_list = []
            for idx in batch_idx:
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                data_list_reverse.append(Data(
                    edge_index=edge_index_reverse[idx].long(),
                    x=node_feature_reverse[idx].float()))
                target_list.append(val_accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32).to(self.device)
            batch = Batch.from_data_list(data_list).to(self.device)
            # `batch_idx_g` avoids shadowing the mini-batch index list above.
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse).to(self.device)
            batch_nodes_reverse = F.normalize(batch_reverse.x, p=2, dim=-1)
            batch_edge_idx_reverse = batch_reverse.edge_index
            pred = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                               batch_nodes_reverse, batch_edge_idx_reverse)
            pred = pred.squeeze()
            loss = self.criterion(pred, val_tensor)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step(epoch + int(i / 30))
            meters.update(loss=loss.item())
    return meters.meters['loss'].avg
def pred(self, edge_index, node_feature):
    pred_list = []
    mean_list = []
    std_list = []
    idx_list = list(range(len(edge_index)))
    self.nas_agent.eval()
    batch_idx_list = gen_batch_idx(idx_list, 64)
    with torch.no_grad():
        for batch_idx in batch_idx_list:
            data_list = [Data(edge_index=edge_index[idx].long(),
                              x=node_feature[idx].float())
                         for idx in batch_idx]
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx_g)
            pred, mean, std = pred.squeeze(), mean.squeeze(), std.squeeze()
            if pred.dim() == 0:
                # Single-graph batch: keep 1-d shapes so torch.cat works below.
                pred, mean, std = pred.unsqueeze(0), mean.unsqueeze(0), std.unsqueeze(0)
            pred_list.append(pred)
            mean_list.append(mean)
            std_list.append(std)
    return (torch.cat(pred_list, dim=0),
            torch.cat(mean_list, dim=0),
            torch.cat(std_list, dim=0))
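
# Sketch of how pred() above could rank a candidate pool. Assumed layout:
# `pool` is a list of arch dicts with 'edge_idx'/'node_f' entries, as
# produced by gen_darts_dataset; `predictor` is the object owning pred().
def _demo_rank_candidates(predictor, pool):
    edge_index = [a['edge_idx'] for a in pool]
    node_feature = [a['node_f'] for a in pool]
    _, mean, _ = predictor.pred(edge_index, node_feature)
    order = torch.argsort(mean, descending=True)
    return [pool[i] for i in order.tolist()]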
def fit(self, edge_index, node_feature, val_accuracy, logger=None):
    meters = MetricLogger(delimiter=" ")
    self.nas_agent.train()
    for epoch in range(self.epoch):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
        counter = 0
        for batch_idx in batch_idx_list:
            counter += len(batch_idx)
            data_list = []
            target_list = []
            for idx in batch_idx:
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                target_list.append(val_accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32).to(self.device)
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx_g)
            val_tensor = val_tensor.unsqueeze(dim=1)
            loss = self.criterion(mean, std, val_tensor)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        if logger:
            logger.info(str(meters))
    return meters.meters['loss'].avg
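
# Hedged sketch of a criterion matching the (mean, std, target) signature
# used in fit() above: a Gaussian negative log-likelihood. The repo's actual
# loss may differ; this is one plausible instantiation, not the original.
def _gaussian_nll_sketch(mean, std, target, eps=1e-6):
    var = std.clamp_min(eps) ** 2
    return 0.5 * (torch.log(var) + (target - mean) ** 2 / var).mean()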
def inference(self, edge_index, node_feature, arch_encoding,
              batch_idx_test_1, batch_idx_test_2):
    self.predictor.eval()
    error_list = []
    precision_list = []
    for i, pair1_idx in enumerate(batch_idx_test_1):
        pair2_idx = batch_idx_test_2[i]
        data_list_pair1, arch_path_encoding_pair1 = [], []
        data_list_pair2, arch_path_encoding_pair2 = [], []
        for idx1, idx2 in zip(pair1_idx, pair2_idx):
            data_list_pair1.append(Data(edge_index=edge_index[idx1].long(),
                                        x=node_feature[idx1].float()))
            arch_path_encoding_pair1.append(arch_encoding[idx1])
            data_list_pair2.append(Data(edge_index=edge_index[idx2].long(),
                                        x=node_feature[idx2].float()))
            arch_path_encoding_pair2.append(arch_encoding[idx2])
        dist_gt = torch.tensor(
            [edit_distance(arch_path_encoding_pair1[j], arch_path_encoding_pair2[j])
             for j in range(len(arch_path_encoding_pair1))],
            dtype=torch.float32).to(self.device)
        batch1 = Batch.from_data_list(data_list_pair1).to(self.device)
        batch2 = Batch.from_data_list(data_list_pair2).to(self.device)
        prediction = self.predictor(batch1.x, batch1.edge_index, batch1.batch,
                                    batch2.x, batch2.edge_index, batch2.batch)
        # Recover the predicted edit distance from the similarity score.
        prediction = -1 * torch.log(prediction.squeeze(dim=-1)) * self.node_num
        errors = torch.abs(dist_gt - prediction)
        precision = (torch.sum(errors < 1) * 1.) / errors.size(0)
        error_list.append(torch.mean(errors).item())
        precision_list.append(precision.item())
    msg = (f'Error is {np.mean(np.array(error_list))}, '
           f'Precision is {np.mean(np.array(precision_list))}')
    if self.logger:
        self.logger.info(msg)
    else:
        print(msg)
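
# One plausible definition of the edit_distance helper used above, assuming
# the path encodings are binary vectors: the number of positions where the
# two encodings disagree. The project's real helper may be more elaborate.
def _edit_distance_sketch(enc1, enc2):
    return float(np.sum(np.array(enc1) != np.array(enc2)))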
def __getitem__(self, idx):
    arch_id = self.idxs[idx]
    arch = self.total_archs[self.total_keys[arch_id]][0][0]
    ops = self.total_archs[self.total_keys[arch_id]][0][1]
    path_encoding = self.total_archs[self.total_keys[arch_id]][-1]
    if self.model_type == 'moco':
        return arch, ops, path_encoding
    elif self.model_type == 'SS_CCL':
        edge_index, node_f = nasbench2graph_201((arch, ops), is_idx=True)
        g_d = Data(edge_index=edge_index.long(), x=node_f.float())
        return g_d, path_encoding
    else:
        raise NotImplementedError(
            f'The model type {self.model_type} is not supported!')
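
# Usage sketch for the 'SS_CCL' branch above: torch_geometric's DataLoader
# collates the returned Data objects into one disconnected batch graph.
# `dataset` is assumed to be an instance of the class owning __getitem__;
# the import path matches torch_geometric >= 2.0.
def _demo_ss_ccl_loader(dataset):
    from torch_geometric.loader import DataLoader
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    g_batch, path_encodings = next(iter(loader))
    print(g_batch.num_graphs, g_batch.x.shape)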
def fit_train(self, edge_index, node_feature, accuracy, logger=None):
    meters = MetricLogger(delimiter=" ")
    self.predictor.train()
    start = time.time()
    for epoch in range(self.epoch):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
        counter = 0
        for i, batch_idx in enumerate(batch_idx_list):
            counter += len(batch_idx)
            data_list = []
            target_list = []
            for idx in batch_idx:
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                target_list.append(accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32).to(self.device)
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            pred = self.predictor(batch_nodes, batch_edge_idx, batch_idx_g)
            pred = pred.squeeze()
            loss = self.criterion(pred, val_tensor)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        if self.save_model:
            save_path = os.path.join(self.save_dir,
                                     f'supervised_gin_epoch_{epoch}.pt')
            torch.save(self.predictor.state_dict(), save_path)
    return meters.meters['loss'].avg
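
# Sketch: reload a checkpoint written by fit_train() above. The path is an
# assumption following the f-string used there, e.g.
# '<save_dir>/supervised_gin_epoch_9.pt'.
def _demo_resume_predictor(predictor_module, ckpt_path, device='cpu'):
    state = torch.load(ckpt_path, map_location=device)
    predictor_module.load_state_dict(state)
    predictor_module.eval()
    return predictor_module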
def fit(self, edge_index, node_feature, edge_index_reverse,
        node_feature_reverse, val_accuracy, val_accuracy_cls):
    meters_cls = MetricLogger(delimiter=" ")
    meters_regress = MetricLogger(delimiter=" ")
    self.stage1.train()
    self.stage2.train()
    for epoch in range(self.epochs):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, 10)
        for i, batch_idx in enumerate(batch_idx_list):
            # Stage 1: classify which architectures are worth regressing.
            data_list, data_list_reverse, target_list_cls = [], [], []
            for idx in batch_idx:
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                data_list_reverse.append(Data(
                    edge_index=edge_index_reverse[idx].long(),
                    x=node_feature_reverse[idx].float()))
                target_list_cls.append(val_accuracy_cls[idx])
            val_cls_tensor = torch.tensor(target_list_cls, dtype=torch.long).to(self.device)
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse).to(self.device)
            batch_nodes_reverse = F.normalize(batch_reverse.x, p=2, dim=-1)
            batch_edge_idx_reverse = batch_reverse.edge_index
            pred_cls = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                                   batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            loss_stage1 = self.criterion_ce(pred_cls, val_cls_tensor)
            self.optimizer_cls.zero_grad()
            loss_stage1.backward()
            self.optimizer_cls.step()
            self.scheduler_cls.step(epoch + int(i / 30))
            meters_cls.update(loss=loss_stage1.item())
            if pred_cls.dim() == 1:
                pred_cls.unsqueeze_(dim=0)
            pred_max = torch.argmax(pred_cls, dim=1)
            if torch.sum(pred_max) == 0:
                continue
            # Stage 2: regress accuracy only on archs classified as positive.
            data_list, data_list_reverse, target_list = [], [], []
            for k, idx in enumerate(batch_idx):
                if pred_max[k] == 0:
                    continue
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                data_list_reverse.append(Data(
                    edge_index=edge_index_reverse[idx].long(),
                    x=node_feature_reverse[idx].float()))
                target_list.append(val_accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32).to(self.device)
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse).to(self.device)
            batch_nodes_reverse = F.normalize(batch_reverse.x, p=2, dim=-1)
            batch_edge_idx_reverse = batch_reverse.edge_index
            pred_regress = self.stage2(batch_nodes, batch_edge_idx, batch_idx_g,
                                       batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            if pred_regress.dim() == 0:
                loss_stage2 = self.criterion(pred_regress, val_tensor[0])
            else:
                loss_stage2 = self.criterion(pred_regress, val_tensor)
            self.optimizer_regress.zero_grad()
            loss_stage2.backward()
            self.optimizer_regress.step()
            self.scheduler_regress.step(epoch + int(i / 30))
            meters_regress.update(loss=loss_stage2.item())
    return meters_cls.meters['loss'].avg, meters_regress.meters['loss'].avg
def pred(self, edge_index, node_feature, edge_index_reverse, node_feature_reverse):
    pred_list = []
    idx_list = list(range(len(edge_index)))
    self.stage1.eval()
    self.stage2.eval()
    batch_idx_list = gen_batch_idx(idx_list, 32)
    with torch.no_grad():
        for i, batch_idx in enumerate(batch_idx_list):
            data_list, data_list_reverse = [], []
            for idx in batch_idx:
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                data_list_reverse.append(Data(
                    edge_index=edge_index_reverse[idx].long(),
                    x=node_feature_reverse[idx].float()))
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse).to(self.device)
            batch_nodes_reverse = F.normalize(batch_reverse.x, p=2, dim=-1)
            batch_edge_idx_reverse = batch_reverse.edge_index
            pred_cls = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                                   batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            if pred_cls.dim() == 1:
                pred_cls.unsqueeze_(dim=0)
            pred_max = torch.argmax(pred_cls, dim=1)
            if pred_max.dim() == 0:
                pred_max.unsqueeze_(dim=0)
            if torch.sum(pred_max) == 0:
                # No arch in this batch passed stage 1; emit zeros so the
                # output stays aligned with the input order (the original
                # `continue` here silently dropped the whole batch).
                pred_list.append(torch.zeros_like(pred_max, dtype=torch.float32))
                continue
            data_list, data_list_reverse = [], []
            for k, idx in enumerate(batch_idx):
                if pred_max[k] == 0:
                    continue
                data_list.append(Data(edge_index=edge_index[idx].long(),
                                      x=node_feature[idx].float()))
                data_list_reverse.append(Data(
                    edge_index=edge_index_reverse[idx].long(),
                    x=node_feature_reverse[idx].float()))
            batch = Batch.from_data_list(data_list).to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse).to(self.device)
            batch_nodes_reverse = F.normalize(batch_reverse.x, p=2, dim=-1)
            batch_edge_idx_reverse = batch_reverse.edge_index
            pred_regress = self.stage2(batch_nodes, batch_edge_idx, batch_idx_g,
                                       batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            if pred_regress.dim() == 0:
                pred_regress.unsqueeze_(dim=0)
            # Scatter stage-2 scores back to their original positions; archs
            # rejected by stage 1 keep a score of 0.
            pred = torch.zeros_like(pred_max, dtype=torch.float32)
            counter = 0
            for j in range(pred.size(0)):
                if pred_max[j] != 0:
                    pred[j] = pred_regress[counter]
                    counter += 1
            pred_list.append(pred)
    return torch.cat(pred_list, dim=0)
def dataset_all(args, nas_dataset):
    total_keys = nas_dataset.total_keys
    total_archs = nas_dataset.total_archs
    all_archs = []
    for k in total_keys:
        arch = total_archs[k]
        if args.search_space == 'nasbench_101':
            cell_inst = Cell_101(matrix=arch['matrix'], ops=arch['ops'])
            edge_index, node_f = nas2graph(args.search_space,
                                           (arch['matrix'], arch['ops']))
            g_data = Data(edge_index=edge_index.long(), x=node_f.float())
            seminas_vec = convert_arch_to_seq(arch['o_matrix'], arch['o_ops'])
            edge_index_reverse, node_f_reverse = nasbench2graph_reverse(
                (arch['matrix'], arch['ops']), reverse=True)
            g_data_reverse = Data(edge_index=edge_index_reverse.long(),
                                  x=node_f_reverse.float())
            if len(seminas_vec) < 27:
                padding = 27 - len(seminas_vec)
                seminas_vec = seminas_vec + [0] * padding
            all_archs.append({
                'matrix': arch['matrix'],
                'ops': arch['ops'],
                'pe_adj_enc_vec': cell_inst.get_encoding('adj_enc_vec', args.seq_len),
                'pe_path_enc_vec': cell_inst.get_encoding('path_enc_vec', args.seq_len),
                'pe_path_enc_aware_vec': cell_inst.get_encoding('path_enc_aware_vec', args.seq_len),
                'val_acc': arch['val'],
                'test_acc': arch['test'],
                'g_data': g_data,
                'arch_k': k,
                'seminas_vec': seminas_vec,
                'edge_idx': edge_index,
                'node_f': node_f,
                'edge_idx_reverse': edge_index_reverse,
                'node_f_reverse': node_f_reverse,
                'g_data_reverse': g_data_reverse,
            })
        elif args.search_space == 'nasbench_201':
            cell_inst = Cell_201(matrix=arch[0][0], ops=arch[0][1])
            edge_index, node_f = nas2graph(args.search_space, (arch[0][0], arch[0][1]))
            edge_index_reverse, node_f_reverse = nas2graph(
                args.search_space, (arch[0][0], arch[0][1]), reverse=True)
            g_data_reverse = Data(edge_index=edge_index_reverse.long(),
                                  x=node_f_reverse.float())
            all_archs.append({
                'matrix': arch[0][0],
                'ops': arch[0][1],
                'pe_adj_enc_vec': cell_inst.get_encoding('adj_enc_vec', args.seq_len),
                'pe_path_enc_vec': cell_inst.get_encoding('path_enc_vec', args.seq_len),
                'pe_path_enc_aware_vec': cell_inst.get_encoding('path_enc_aware_vec', args.seq_len),
                # arch[4]/arch[5] appear to hold error rates in percent,
                # hence the conversion to a [0, 1] accuracy.
                'val_acc': (100 - arch[4]) * 0.01,
                'test_acc': (100 - arch[5]) * 0.01,
                'g_data': Data(edge_index=edge_index.long(), x=node_f.float()),
                'arch_k': k,
                'edge_idx': edge_index,
                'node_f': node_f,
                'edge_idx_reverse': edge_index_reverse,
                'node_f_reverse': node_f_reverse,
                'g_data_reverse': g_data_reverse,
            })
        else:
            raise NotImplementedError()
    return all_archs
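
# Sketch of a downstream split over the pool returned by dataset_all():
# sample a small labelled training set and keep the rest as candidates, the
# way a predictor's fit()/pred() cycle would consume them. The helper name
# and sizes are illustrative only.
def _demo_split_pool(all_archs, n_train=100, seed=0):
    pool = list(all_archs)
    random.Random(seed).shuffle(pool)
    train, candidates = pool[:n_train], pool[n_train:]
    train_inputs = ([a['edge_idx'] for a in train],
                    [a['node_f'] for a in train],
                    [a['val_acc'] for a in train])
    return train_inputs, candidates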