Example #1
import os
import pickle

from torch_geometric.data import Data

# DataSetDarts, convert_genotype_form, OPS, ArchDarts and nasbench2graph2
# are assumed to come from the surrounding project.


def gen_darts_dataset(base_path, save_path=None):
    files = [os.path.join(base_path, f) for f in os.listdir(base_path) if not f.endswith('txt')]
    all_archs = []
    darts_dataset = DataSetDarts()
    for f in files:
        with open(f, 'rb') as fb:
            genotype_list = pickle.load(fb)
            keys_list = pickle.load(fb)
        for (genotype, key) in zip(genotype_list, keys_list):
            if len(all_archs) % 10000 == 0 and len(all_archs) != 0:
                print(f'{len(all_archs)} archs have been processed!')
            f_new = convert_genotype_form(genotype, OPS)
            arch = (f_new.normal, f_new.reduce)
            arch_darts = ArchDarts(arch)
            path_encoding_position_aware = arch_darts.get_path(
                path_type='path_enc_aware_vec',
                seq_len=612
            )
            path_encoding = arch_darts.get_path(
                path_type='path_enc_vec',
                seq_len=612
            )
            path_adj_encoding = arch_darts.get_path(
                path_type='adj_enc_vec',
                seq_len=612
            )

            matrix, ops = darts_dataset.assemble_graph_from_single_arch(arch)
            edge_indices, node_features = nasbench2graph2((matrix, ops))
            edge_reverse_indices, node_reverse_features = nasbench2graph2((matrix, ops), reverse=True)

            all_archs.append(
                {
                    'matrix': matrix,
                    'ops': ops,
                    'pe_adj_enc_vec': path_adj_encoding,
                    'pe_path_enc_vec': path_encoding,
                    'pe_path_enc_aware_vec': path_encoding_position_aware,
                    'hash_key': key,
                    'genotype': genotype,
                    'edge_idx': edge_indices,
                    'node_f': node_features,
                    'g_data': Data(edge_index=edge_indices.long(), x=node_features.float()),
                    'edge_idx_reverse': edge_reverse_indices,
                    'node_f_reverse': node_reverse_features,
                    'g_data_reverse': Data(edge_index=edge_reverse_indices.long(), x=node_reverse_features.float()),
                }
            )
    if save_path:
        with open(save_path, 'wb') as fb:
            pickle.dump(all_archs, fb)
    return all_archs
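A minimal usage sketch (the directory layout and file names below are assumptions, not part of the original snippet):

archs = gen_darts_dataset('data/darts_archs', save_path='data/darts_dataset.pkl')
print(len(archs), sorted(archs[0].keys()))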
Example #2
    def fit(self, edge_index, node_feature, edge_index_reverse,
            node_feature_reverse, val_accuracy):
        meters = MetricLogger(delimiter="  ")
        self.stage1.train()
        for epoch in range(self.epochs):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, 10)
            for i, batch_idx in enumerate(batch_idx_list):
                data_list = []
                data_list_reverse = []
                target_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                    target_list.append(val_accuracy[idx])

                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_tensor = val_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = (
                    batch_reverse.x, batch_reverse.edge_index, batch_reverse.batch)
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)

                pred = self.stage1(batch_nodes, batch_edge_idx, batch_idx,
                                   batch_nodes_reverse, batch_edge_idx_reverse)
                pred = pred.squeeze()
                loss = self.criterion(pred, val_tensor)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step(epoch + int(i / 30))
                meters.update(loss=loss.item())
        return meters.meters['loss'].avg
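These snippets rely on a gen_batch_idx helper that is not shown; a plausible re-implementation (an assumption; the original may differ) simply chunks the shuffled index list:

def gen_batch_idx(idx_list, batch_size):
    # Split the indices into consecutive chunks; the last chunk may be smaller.
    return [idx_list[i:i + batch_size]
            for i in range(0, len(idx_list), batch_size)]

Example #3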
    def pred(self, edge_index, node_feature):
        pred_list = []
        mean_list = []
        std_list = []
        idx_list = list(range(len(edge_index)))
        self.nas_agent.eval()
        batch_idx_list = gen_batch_idx(idx_list, 64)
        with torch.no_grad():
            for batch_idx in batch_idx_list:
                data_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                    data_list.append(g_d)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx = batch.x, batch.edge_index, batch.batch
                pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx)

                pred = pred.squeeze()
                mean = mean.squeeze()
                std = std.squeeze()
                if len(pred.size()) == 0:
                    pred.unsqueeze_(0)
                    mean.unsqueeze_(0)
                    std.unsqueeze_(0)
                pred_list.append(pred)
                mean_list.append(mean)
                std_list.append(std)
        return torch.cat(pred_list, dim=0), torch.cat(mean_list, dim=0), torch.cat(std_list, dim=0)
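The three returned tensors are aligned along dimension 0, so the caller can rank architectures directly, e.g. (the object name is assumed):

preds, means, stds = trainer.pred(edge_index, node_feature)
ranked = torch.argsort(means, descending=True)  # best predicted archs first

Example #4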
    def fit(self, edge_index, node_feature, val_accuracy, logger=None):
        meters = MetricLogger(delimiter=" ")
        self.nas_agent.train()
        for epoch in range(self.epoch):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
            counter = 0
            for batch_idx in batch_idx_list:
                counter += len(batch_idx)
                data_list = []
                target_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                    data_list.append(g_d)
                    target_list.append(val_accuracy[idx])
                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_tensor = val_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx = batch.x, batch.edge_index, batch.batch
                # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx)
                val_tensor = val_tensor.unsqueeze(dim=1)
                loss = self.criterion(mean, std, val_tensor)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                meters.update(loss=loss.item())
        if logger:
            logger.info(str(meters))
        return meters.meters['loss'].avg
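The criterion(mean, std, val_tensor) call implies a distribution-aware loss; a minimal sketch of a Gaussian negative log-likelihood with that signature (an assumption; the actual loss class may differ):

import torch

def gaussian_nll(mean, std, target, eps=1e-6):
    # NLL of target under N(mean, std^2); eps keeps the variance positive.
    var = std.pow(2) + eps
    return 0.5 * (torch.log(var) + (target - mean).pow(2) / var).mean()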
Example #5
    def inference(self, edge_index, node_feature, arch_encoding, batch_idx_test_1, batch_idx_test_2):
        self.predictor.eval()
        error_list = []
        precision_list = []
        for i, pair1_idx in enumerate(batch_idx_test_1):
            pair2_idx = batch_idx_test_2[i]
            data_list_pair1 = []
            arch_path_encoding_pair1 = []

            data_list_pair2 = []
            arch_path_encoding_pair2 = []

            for pair_idx in zip(pair1_idx, pair2_idx):
                idx1, idx2 = pair_idx
                g_d_1 = Data(edge_index=edge_index[idx1].long(), x=node_feature[idx1].float())
                data_list_pair1.append(g_d_1)
                arch_path_encoding_pair1.append(arch_encoding[idx1])

                g_d_2 = Data(edge_index=edge_index[idx2].long(), x=node_feature[idx2].float())
                data_list_pair2.append(g_d_2)
                arch_path_encoding_pair2.append(arch_encoding[idx2])

            dist_gt = torch.tensor([edit_distance(arch_path_encoding_pair1[j], arch_path_encoding_pair2[j])
                                    for j in range(len(arch_path_encoding_pair1))], dtype=torch.float32)
            batch1 = Batch.from_data_list(data_list_pair1)
            batch1 = batch1.to(self.device)

            batch2 = Batch.from_data_list(data_list_pair2)
            batch2 = batch2.to(self.device)

            dist_gt = dist_gt.to(self.device)

            batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
            batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
            prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1, batch_nodes_2,
                                        batch_edge_idx_2, batch_idx_2)
            prediction = -1 * torch.log(prediction.squeeze(dim=-1)) * self.node_num

            errors = torch.abs(dist_gt - prediction)
            precision = (torch.sum(errors < 1) * 1.) / errors.size(0)
            error_list.append(torch.mean(errors).item())
            precision_list.append(precision.item())
        msg = f'Error is {np.mean(error_list)}, Precision is {np.mean(precision_list)}'
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
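The edit_distance helper is not shown; for fixed-length path encodings, a Hamming-style count of disagreeing positions is the natural reading (an assumption; the original helper may differ):

import numpy as np

def edit_distance(enc1, enc2):
    # Count positions where the two path-encoding vectors disagree.
    return int(np.sum(np.array(enc1) != np.array(enc2)))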
Example #6
    def __getitem__(self, idx):
        arch_id = self.idxs[idx]
        arch_record = self.total_archs[self.total_keys[arch_id]]

        if self.model_type == 'moco':
            arch = arch_record[0][0]
            ops = arch_record[0][1]
            path_encoding = arch_record[-1]
            return arch, ops, path_encoding
        elif self.model_type == 'SS_CCL':
            arch = arch_record[0][0]
            ops = arch_record[0][1]
            path_encoding = arch_record[-1]

            edge_index, node_f = nasbench2graph_201((arch, ops), is_idx=True)
            g_d = Data(edge_index=edge_index.long(), x=node_f.float())
            return g_d, path_encoding
        else:
            raise NotImplementedError(
                f'The model type {self.model_type} is not supported!')
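A hypothetical way to consume the 'SS_CCL' branch, where each item is a (Data, path_encoding) pair; the dataset class name and constructor signature are assumptions:

from torch_geometric.loader import DataLoader

dataset = ArchDataset(model_type='SS_CCL')  # name and signature assumed
loader = DataLoader(dataset, batch_size=64, shuffle=True)
for g_batch, path_enc in loader:  # Data objects are collated into one Batch
    pass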
Example #7
    def fit_train(self, edge_index, node_feature, accuracy, logger=None):
        meters = MetricLogger(delimiter=" ")
        self.predictor.train()
        start = time.time()
        for epoch in range(self.epoch):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
            counter = 0
            for i, batch_idx in enumerate(batch_idx_list):
                counter += len(batch_idx)
                data_list = []
                target_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                    data_list.append(g_d)
                    target_list.append(accuracy[idx])
                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_tensor = val_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx = batch.x, batch.edge_index, batch.batch
                # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                pred = self.predictor(batch_nodes, batch_edge_idx, batch_idx)
                pred = pred.squeeze()
                loss = self.criterion(pred, val_tensor)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                meters.update(loss=loss.item())
            # print(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
            if self.save_model:
                save_path = os.path.join(self.save_dir, f'supervised_gin_epoch_{epoch}.pt')
                torch.save(self.predictor.state_dict(), save_path)
        return meters.meters['loss'].avg
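Since fit_train checkpoints the predictor every epoch, a later run can restore it like this (the predictor construction and epoch index are assumptions):

import os
import torch

state = torch.load(os.path.join(save_dir, 'supervised_gin_epoch_0.pt'))
predictor.load_state_dict(state)
predictor.eval()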
Example #8
    def fit(self, edge_index, node_feature, edge_index_reverse,
            node_feature_reverse, val_accuracy, val_accuracy_cls):
        meters_cls = MetricLogger(delimiter="  ")
        meters_regress = MetricLogger(delimiter="  ")
        self.stage1.train()
        self.stage2.train()
        for epoch in range(self.epochs):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, 10)

            for i, batch_idx in enumerate(batch_idx_list):
                data_list = []
                data_list_reverse = []
                target_list_cls = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                    target_list_cls.append(val_accuracy_cls[idx])

                val_cls_tensor = torch.tensor(target_list_cls,
                                              dtype=torch.long)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_cls_tensor = val_cls_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = (
                    batch_reverse.x, batch_reverse.edge_index, batch_reverse.batch)
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)

                pred_cls = self.stage1(batch_nodes, batch_edge_idx,
                                       batch_idx_g, batch_nodes_reverse,
                                       batch_edge_idx_reverse).squeeze()
                loss_stage1 = self.criterion_ce(pred_cls, val_cls_tensor)
                self.optimizer_cls.zero_grad()
                loss_stage1.backward()
                self.optimizer_cls.step()
                self.scheduler_cls.step(epoch + int(i / 30))
                meters_cls.update(loss=loss_stage1.item())

                if pred_cls.dim() == 1:
                    pred_cls.unsqueeze_(dim=0)

                pred_max = torch.argmax(pred_cls, dim=1)
                if torch.sum(pred_max) == 0:
                    continue
                data_list = []
                data_list_reverse = []
                target_list = []
                for k, idx in enumerate(batch_idx):
                    if pred_max[k] == 0:
                        continue
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                    target_list.append(val_accuracy[idx])
                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                val_tensor = val_tensor.to(self.device)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = (
                    batch_reverse.x, batch_reverse.edge_index, batch_reverse.batch)
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)
                pred_regress = self.stage2(batch_nodes, batch_edge_idx,
                                           batch_idx_g, batch_nodes_reverse,
                                           batch_edge_idx_reverse).squeeze()
                if pred_regress.dim() == 0:
                    loss_stage2 = self.criterion(pred_regress, val_tensor[0])
                else:
                    loss_stage2 = self.criterion(pred_regress, val_tensor)
                self.optimizer_regress.zero_grad()
                loss_stage2.backward()
                self.optimizer_regress.step()
                self.scheduler_regress.step(epoch + int(i / 30))
                meters_regress.update(loss=loss_stage2.item())
        return meters_cls.meters['loss'].avg, meters_regress.meters['loss'].avg
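val_accuracy_cls is consumed as class indices by the cross-entropy loss; one plausible construction (an assumption; the real labels may be built differently) is a binary split on a validation-accuracy threshold:

threshold = 0.9  # assumed cut-off
val_accuracy_cls = [1 if acc > threshold else 0 for acc in val_accuracy]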
Example #9
    def pred(self, edge_index, node_feature, edge_index_reverse,
             node_feature_reverse):
        pred_list = []
        idx_list = list(range(len(edge_index)))
        self.stage1.eval()
        self.stage2.eval()
        batch_idx_list = gen_batch_idx(idx_list, 32)
        with torch.no_grad():
            for i, batch_idx in enumerate(batch_idx_list):
                data_list = []
                data_list_reverse = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = (
                    batch_reverse.x, batch_reverse.edge_index, batch_reverse.batch)
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)
                pred_cls = self.stage1(batch_nodes, batch_edge_idx,
                                       batch_idx_g, batch_nodes_reverse,
                                       batch_edge_idx_reverse).squeeze()
                if pred_cls.dim() == 1:
                    pred_cls.unsqueeze_(dim=0)
                pred_max = torch.argmax(pred_cls, dim=1)
                if pred_max.dim() == 0:
                    pred_max.unsqueeze_(dim=0)
                if torch.sum(pred_max) == 0:
                    # Every arch in this batch was classified as class 0; emit
                    # zeros so the output stays aligned with the input order.
                    pred_list.append(torch.zeros_like(pred_max, dtype=torch.float32))
                    continue
                data_list = []
                data_list_reverse = []
                for k, idx in enumerate(batch_idx):
                    if pred_max[k] == 0:
                        continue
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)

                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = (
                    batch_reverse.x, batch_reverse.edge_index, batch_reverse.batch)
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)
                pred_regress = self.stage2(batch_nodes, batch_edge_idx,
                                           batch_idx_g, batch_nodes_reverse,
                                           batch_edge_idx_reverse).squeeze()
                pred = torch.zeros_like(pred_max, dtype=torch.float32)
                if pred_regress.dim() == 0:
                    pred_regress.unsqueeze_(dim=0)
                # print(pred_max.size(), pred_regress.size(), pred.size())
                counter = 0
                for j in range(pred.size(0)):
                    if pred_max[j] == 0:
                        pred[j] = 0
                    else:
                        pred[j] = pred_regress[counter]
                        counter += 1
                if len(pred.size()) == 0:
                    pred.unsqueeze_(0)
                pred_list.append(pred)
        return torch.cat(pred_list, dim=0)
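A hypothetical ranking step on top of pred(): architectures the classifier rejected come back as zeros, so the highest scores are the stage-2 survivors:

preds = model.pred(edge_index, node_feature, edge_index_reverse, node_feature_reverse)
topk_scores, topk_idx = torch.topk(preds, k=10)  # k is arbitrary here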
Example #10
from torch_geometric.data import Data

# Cell_101, Cell_201, nas2graph, nasbench2graph_reverse and
# convert_arch_to_seq are assumed to come from the surrounding project.


def dataset_all(args, nas_dataset):
    total_keys = nas_dataset.total_keys
    total_archs = nas_dataset.total_archs
    all_archs = []

    flag = args.search_space == 'nasbench_101'
    for k in total_keys:
        arch = total_archs[k]
        if args.search_space == 'nasbench_101':
            cell_inst = Cell_101(matrix=arch['matrix'], ops=arch['ops'])
            edge_index, node_f = nas2graph(args.search_space,
                                           (arch['matrix'], arch['ops']))
            g_data = Data(edge_index=edge_index.long(), x=node_f.float())
            seminas_vec = convert_arch_to_seq(arch['o_matrix'], arch['o_ops'])
            edge_index_reverse, node_f_reverse = nasbench2graph_reverse(
                (arch['matrix'], arch['ops']), reverse=True)
            g_data_reverse = Data(edge_index=edge_index_reverse.long(),
                                  x=node_f_reverse.float())
            if len(seminas_vec) < 27:
                padding = 27 - len(seminas_vec)
                seminas_vec = seminas_vec + [0 for _ in range(padding)]
            all_archs.append({
                'matrix': arch['matrix'] if flag else arch[0][0],
                'ops': arch['ops'] if flag else arch[0][1],
                'pe_adj_enc_vec': cell_inst.get_encoding('adj_enc_vec', args.seq_len),
                'pe_path_enc_vec': cell_inst.get_encoding('path_enc_vec', args.seq_len),
                'pe_path_enc_aware_vec': cell_inst.get_encoding('path_enc_aware_vec', args.seq_len),
                'val_acc': arch['val'] if flag else (100 - arch[4]) * 0.01,
                'test_acc': arch['test'] if flag else (100 - arch[5]) * 0.01,
                'g_data': g_data,
                'arch_k': k,
                'seminas_vec': seminas_vec,
                'edge_idx': edge_index,
                'node_f': node_f,
                'edge_idx_reverse': edge_index_reverse,
                'node_f_reverse': node_f_reverse,
                'g_data_reverse': g_data_reverse,
            })
        elif args.search_space == 'nasbench_201':
            cell_inst = Cell_201(matrix=arch[0][0], ops=arch[0][1])
            edge_index, node_f = nas2graph(args.search_space,
                                           (arch[0][0], arch[0][1]))
            edge_index_reverse, node_f_reverse = nas2graph(
                args.search_space, (arch[0][0], arch[0][1]), reverse=True)
            g_data_reverse = Data(edge_index=edge_index_reverse.long(),
                                  x=node_f_reverse.float())
            all_archs.append({
                'matrix': arch['matrix'] if flag else arch[0][0],
                'ops': arch['ops'] if flag else arch[0][1],
                'pe_adj_enc_vec': cell_inst.get_encoding('adj_enc_vec', args.seq_len),
                'pe_path_enc_vec': cell_inst.get_encoding('path_enc_vec', args.seq_len),
                'pe_path_enc_aware_vec': cell_inst.get_encoding('path_enc_aware_vec', args.seq_len),
                'val_acc': arch['val'] if flag else (100 - arch[4]) * 0.01,
                'test_acc': arch['test'] if flag else (100 - arch[5]) * 0.01,
                'g_data': Data(edge_index=edge_index.long(), x=node_f.float()),
                'arch_k': k,
                'edge_idx': edge_index,
                'node_f': node_f,
                'edge_idx_reverse': edge_index_reverse,
                'node_f_reverse': node_f_reverse,
                'g_data_reverse': g_data_reverse,
            })
        else:
            raise NotImplementedError()
    return all_archs
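A minimal invocation sketch; args is assumed to expose search_space and seq_len:

archs = dataset_all(args, nas_dataset)
print(archs[0]['val_acc'], archs[0]['test_acc'])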