def extract_gcn_v(opath_feat, opath_pred_confs, data_name, cfg,
                  write_gcn_feat=False):
    # skip extraction if both outputs already exist
    if osp.isfile(opath_feat) and osp.isfile(opath_pred_confs):
        print('{} and {} already exist.'.format(opath_feat, opath_pred_confs))
        return
    cfg.cuda = torch.cuda.is_available()

    logger = create_logger()

    model = build_model(cfg.model['type'], **cfg.model['kwargs'])

    # propagate model kwargs (e.g. feature dim) to the dataset config
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg[data_name], k, v)
    cfg[data_name].eval_interim = False

    dataset = build_dataset(cfg.model['type'], cfg[data_name])

    pred_confs, gcn_feat = test(model, dataset, cfg, logger)

    if not osp.exists(opath_pred_confs):
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=dataset.inst_num)

    if not osp.exists(opath_feat) and write_gcn_feat:
        logger.info('save gcn features to {}'.format(opath_feat))
        mkdir_if_no_exists(opath_feat)
        write_feat(opath_feat, gcn_feat)
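# A hypothetical invocation sketch (commented out so it never runs at import
# time): the config path, checkpoint and output locations below are
# assumptions, and `Config` refers to the mmcv-style loader this codebase
# relies on for its cfg objects.
#
#   from mmcv import Config
#
#   cfg = Config.fromfile('vegcn/configs/cfg_test_gcnv.py')  # hypothetical path
#   cfg.load_from = 'pretrained/gcn_v.pth'                   # hypothetical ckpt
#   extract_gcn_v(opath_feat='work_dir/features/gcn_v.bin',
#                 opath_pred_confs='work_dir/pred_confs/gcn_v.npz',
#                 data_name='test_data',
#                 cfg=cfg,
#                 write_gcn_feat=True)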
def train_gcn_v(model, cfg, logger):
    # prepare dataset
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.train_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.train_data)

    # train
    if cfg.distributed:
        raise NotImplementedError
    else:
        _single_train(model, dataset, cfg, logger)
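# A minimal sketch of wiring train_gcn_v together (commented out), assuming
# an mmcv-style config that defines `model`, `train_data` and `distributed`;
# the config file name is hypothetical.
#
#   from mmcv import Config
#
#   cfg = Config.fromfile('vegcn/configs/cfg_train_gcnv.py')  # hypothetical path
#   model = build_model(cfg.model['type'], **cfg.model['kwargs'])
#   logger = create_logger()
#   train_gcn_v(model, cfg, logger)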
def train_gcn_e(model, cfg, logger):
    # prepare data loaders
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.train_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.train_data)
    data_loaders = [
        build_dataloader(dataset,
                         cfg.batch_size_per_gpu,
                         cfg.workers_per_gpu,
                         train=True,
                         shuffle=True)
    ]

    # train
    if cfg.distributed:
        raise NotImplementedError
    else:
        _single_train(model, data_loaders, cfg, logger)
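# `_single_train` is implemented elsewhere in this repo. For orientation, a
# minimal single-GPU loop over `data_loaders` might look like the sketch
# below; the optimizer, loss, the (features, labels) batch layout, and the
# `cfg.lr` / `cfg.total_epochs` fields are assumptions, not the actual
# implementation (which also handles lr scheduling, logging and checkpoints).
def _example_single_train_loop(model, data_loaders, cfg):
    import torch.nn.functional as F
    import torch.optim as optim

    optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9)
    model.train()
    for _ in range(cfg.total_epochs):
        for data_loader in data_loaders:
            for features, labels in data_loader:
                logits = model(features)
                loss = F.cross_entropy(logits, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()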
def test_gcn_e(model, cfg, logger):
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    pred_peaks = dataset.peaks
    pred_dist2peak = dataset.dist2peak

    # reuse cached predictions unless a re-run is forced
    ofn_pred = osp.join(cfg.work_dir, 'pred_conns.npz')
    if osp.isfile(ofn_pred) and not cfg.force:
        data = np.load(ofn_pred)
        pred_conns = data['pred_conns']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warning(
                'instance number in {} is different from dataset: {} vs {}'.
                format(ofn_pred, inst_num, len(dataset)))
    else:
        if cfg.random_conns:
            # random baseline: pick a random neighbor instead of the model's
            # predicted connection
            pred_conns = []
            for nbr, dist, idx in zip(dataset.subset_nbrs,
                                      dataset.subset_dists,
                                      dataset.subset_idxs):
                for _ in range(cfg.max_conn):
                    pred_rel_nbr = np.random.choice(np.arange(len(nbr)))
                    pred_abs_nbr = nbr[pred_rel_nbr]
                    pred_peaks[idx].append(pred_abs_nbr)
                    pred_dist2peak[idx].append(dist[pred_rel_nbr])
                    pred_conns.append(pred_rel_nbr)
            pred_conns = np.array(pred_conns)
        else:
            pred_conns = test(model, dataset, cfg, logger)
            for pred_rel_nbr, nbr, dist, idx in zip(pred_conns,
                                                    dataset.subset_nbrs,
                                                    dataset.subset_dists,
                                                    dataset.subset_idxs):
                pred_abs_nbr = nbr[pred_rel_nbr]
                pred_peaks[idx].extend(pred_abs_nbr)
                pred_dist2peak[idx].extend(dist[pred_rel_nbr])
        inst_num = dataset.inst_num

    if len(pred_conns) > 0:
        logger.info(
            'pred_conns (nbr order): mean({:.1f}), max({}), min({})'.format(
                pred_conns.mean(), pred_conns.max(), pred_conns.min()))

    if not dataset.ignore_label and cfg.eval_interim:
        subset_gt_labels = dataset.subset_gt_labels
        for i in range(cfg.max_conn):
            pred_peaks_labels = np.array([
                dataset.idx2lb[pred_peaks[idx][i]]
                for idx in dataset.subset_idxs
            ])

            acc = accuracy(pred_peaks_labels, subset_gt_labels)
            logger.info(
                '[{}-th] accuracy of pred_peaks labels ({}): {:.4f}'.format(
                    i, len(pred_peaks_labels), acc))

            # the rule for nearest nbr is only appropriate when nbrs is sorted
            nearest_idxs = np.where(pred_conns[:, i] == 0)[0]
            acc = accuracy(pred_peaks_labels[nearest_idxs],
                           subset_gt_labels[nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (nearest: {}): {:.4f}'.format(
                    i, len(nearest_idxs), acc))

            not_nearest_idxs = np.where(pred_conns[:, i] > 0)[0]
            acc = accuracy(pred_peaks_labels[not_nearest_idxs],
                           subset_gt_labels[not_nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (not nearest: {}): {:.4f}'.
                format(i, len(not_nearest_idxs), acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                      inst_num)

    if cfg.save_output:
        logger.info(
            'save predicted connectivity and labels to {}'.format(ofn_pred))
        if not osp.isfile(ofn_pred) or cfg.force:
            np.savez_compressed(ofn_pred,
                                pred_conns=pred_conns,
                                inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        folder = '{}_gcne_k_{}_th_{}_ig_{}'.format(cfg.test_name, cfg.knn,
                                                   cfg.th_sim,
                                                   cfg.test_data.ignore_ratio)
        opath_pred_labels = osp.join(
            cfg.work_dir, folder, 'tau_{}_pred_labels.txt'.format(cfg.tau))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

        # H and C-scores
        gt_dict = {}
        pred_dict = {}
        for i in range(len(dataset.gt_labels)):
            gt_dict[str(i)] = dataset.gt_labels[i]
            pred_dict[str(i)] = pred_labels[i]
        bm = ClusteringBenchmark(gt_dict)
        scores = bm.evaluate_vmeasure(pred_dict)
        # fmi_scores = bm.evaluate_fowlkes_mallows_score(pred_dict)
        print(scores)
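# The H and C-scores above come from ClusteringBenchmark. As an independent
# cross-check, the same homogeneity/completeness/V-measure triple can be
# computed with scikit-learn (commented out; assumes sklearn is installed):
#
#   from sklearn.metrics import homogeneity_completeness_v_measure
#
#   h, c, v = homogeneity_completeness_v_measure(dataset.gt_labels,
#                                                pred_labels)
#   print('homogeneity={:.4f}, completeness={:.4f}, v_measure={:.4f}'
#         .format(h, c, v))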
def test_gcn_v(model, cfg, logger):
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    folder = '{}_gcnv_k_{}_th_{}'.format(cfg.test_name, cfg.knn, cfg.th_sim)
    oprefix = osp.join(cfg.work_dir, folder)
    oname = osp.basename(rm_suffix(cfg.load_from))
    opath_pred_confs = osp.join(oprefix, 'pred_confs', '{}.npz'.format(oname))

    # reuse cached confidences unless a re-run is forced
    if osp.isfile(opath_pred_confs) and not cfg.force:
        data = np.load(opath_pred_confs)
        pred_confs = data['pred_confs']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warning(
                'instance number in {} is different from dataset: {} vs {}'.
                format(opath_pred_confs, inst_num, len(dataset)))
    else:
        pred_confs, gcn_feat = test(model, dataset, cfg, logger)
        inst_num = dataset.inst_num

    logger.info('pred_confs: mean({:.4f}), max({:.4f}), min({:.4f})'.format(
        pred_confs.mean(), pred_confs.max(), pred_confs.min()))

    logger.info('Convert to cluster')
    with Timer('Prediction to peaks'):
        pred_dist2peak, pred_peaks = confidence_to_peaks(
            dataset.dists, dataset.nbrs, pred_confs, cfg.max_conn)

    if not dataset.ignore_label and cfg.eval_interim:
        # evaluate the intermediate results
        for i in range(cfg.max_conn):
            num = len(dataset.peaks)
            pred_peaks_i = np.arange(num)
            peaks_i = np.arange(num)
            for j in range(num):
                if len(pred_peaks[j]) > i:
                    pred_peaks_i[j] = pred_peaks[j][i]
                if len(dataset.peaks[j]) > i:
                    peaks_i[j] = dataset.peaks[j][i]
            acc = accuracy(pred_peaks_i, peaks_i)
            logger.info('[{}-th conn] accuracy of peak match: {:.4f}'.format(
                i + 1, acc))
            acc = 0.
            for idx, peak in enumerate(pred_peaks_i):
                acc += int(dataset.idx2lb[peak] == dataset.idx2lb[idx])
            acc /= len(pred_peaks_i)
            logger.info(
                '[{}-th conn] accuracy of peak label match: {:.4f}'.format(
                    i + 1, acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau_0)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau_0,
                                      inst_num)

    if cfg.save_output:
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)
        opath_pred_labels = osp.join(
            cfg.work_dir, folder, 'tau_{}_pred_labels.txt'.format(cfg.tau_0))
        logger.info('save predicted labels to {}'.format(opath_pred_labels))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

    if cfg.use_gcn_feat:
        # gcn_feat is saved to disk for GCN-E
        opath_feat = osp.join(oprefix, 'features', '{}.bin'.format(oname))
        if not osp.isfile(opath_feat) or cfg.force:
            mkdir_if_no_exists(opath_feat)
            write_feat(opath_feat, gcn_feat)

        name = rm_suffix(osp.basename(opath_feat))
        prefix = oprefix
        ds = BasicDataset(name=name,
                          prefix=prefix,
                          dim=cfg.model['kwargs']['nhid'],
                          normalize=True)
        ds.info()

        # use top embedding of GCN to rebuild the kNN graph
        with Timer('connect to higher confidence with use_gcn_feat'):
            knn_prefix = osp.join(prefix, 'knns', name)
            knns = build_knns(knn_prefix,
                              ds.features,
                              cfg.knn_method,
                              cfg.knn,
                              is_rebuild=True)
            dists, nbrs = knns2ordered_nbrs(knns)
            pred_dist2peak, pred_peaks = confidence_to_peaks(
                dists, nbrs, pred_confs, cfg.max_conn)
            pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                          inst_num)

        # save clustering results
        if cfg.save_output:
            oname_meta = '{}_gcn_feat'.format(name)
            opath_pred_labels = osp.join(
                oprefix, oname_meta, 'tau_{}_pred_labels.txt'.format(cfg.tau))
            mkdir_if_no_exists(opath_pred_labels)
            idx2lb = list2dict(pred_labels, ignore_value=-1)
            write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

        # evaluation
        if not dataset.ignore_label:
            print('==> evaluation')
            for metric in cfg.metrics:
                evaluate(dataset.gt_labels, pred_labels, metric)

        # copy each image into a folder named after its predicted cluster
        import json
        import os
        import shutil

        img_labels = json.load(
            open(r'/home/finn/research/data/clustering_data/test_index.json',
                 'r',
                 encoding='utf-8'))
        output = r'/home/finn/research/data/clustering_data/mr_gcn_output'
        for label in set(pred_labels):
            cluster_dir = os.path.join(output, f'cluster_{label}')
            if not os.path.exists(cluster_dir):
                os.mkdir(cluster_dir)
        for image in img_labels:
            shutil.copy2(
                image,
                os.path.join(output,
                             f'cluster_{pred_labels[img_labels[image]]}',
                             os.path.split(image)[-1]))
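# A hypothetical end-to-end sketch for GCN-V testing (commented out): the
# config path is an assumption, and checkpoint loading is left to the caller
# as elsewhere in this repo.
#
#   from mmcv import Config
#
#   cfg = Config.fromfile('vegcn/configs/cfg_test_gcnv.py')  # hypothetical path
#   model = build_model(cfg.model['type'], **cfg.model['kwargs'])
#   logger = create_logger()
#   test_gcn_v(model, cfg, logger)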