예제 #1
0
def extract_gcn_v(opath_feat,
                  opath_pred_confs,
                  data_name,
                  cfg,
                  write_gcn_feat=False):
    """Run GCN-V on ``cfg[data_name]`` and store its outputs on disk.

    The predicted confidences are written to ``opath_pred_confs`` as a
    compressed ``.npz`` with ``pred_confs`` and ``inst_num``. When
    ``write_gcn_feat`` is True the GCN features returned by ``test`` are
    also written to ``opath_feat``. Nothing is recomputed when every
    requested output already exists.

    Args:
        opath_feat: Output path for the GCN features.
        opath_pred_confs: Output path for the predicted confidences.
        data_name: Key of the dataset config inside ``cfg``.
        cfg: Config object; ``cfg.cuda`` and ``cfg[data_name]`` are mutated.
        write_gcn_feat: Whether the GCN features should be written as well.
    """
    if osp.isfile(opath_pred_confs):
        if osp.isfile(opath_feat):
            print('{} and {} already exist.'.format(opath_feat,
                                                    opath_pred_confs))
            return
        # The feature file is only produced when write_gcn_feat is set; if it
        # is not requested, the existing confidences are all we would write,
        # so re-running inference would be wasted work.
        if not write_gcn_feat:
            print('{} already exists.'.format(opath_pred_confs))
            return
    cfg.cuda = torch.cuda.is_available()

    logger = create_logger()

    model = build_model(cfg.model['type'], **cfg.model['kwargs'])

    # Model hyper-parameters are also needed by the dataset builder.
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg[data_name], k, v)
    cfg[data_name].eval_interim = False

    dataset = build_dataset(cfg.model['type'], cfg[data_name])

    pred_confs, gcn_feat = test(model, dataset, cfg, logger)

    if not osp.exists(opath_pred_confs):
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=dataset.inst_num)

    if not osp.exists(opath_feat) and write_gcn_feat:
        logger.info('save gcn features to {}'.format(opath_feat))
        mkdir_if_no_exists(opath_feat)
        write_feat(opath_feat, gcn_feat)
예제 #2
0
 def __init__(self, feats, k, index_path='', verbose=True):
     """Build (or load) an nmslib HNSW index on ``feats`` and query its top-k.

     Args:
         feats: Data points added with ``addDataPointBatch`` and queried
             with ``knnQueryBatch``.
         k: Number of neighbours retrieved per point.
         index_path: Optional index file. An existing file is loaded;
             otherwise the freshly built index is saved there. Cached kNN
             results are read from ``index_path + '.npz'`` when present.
         verbose: Forwarded to ``Timer`` and ``createIndex``.
     """
     import nmslib
     self.verbose = verbose
     with Timer('[hnsw] build index', verbose):
         # Higher ef leads to better accuracy but slower search; higher M
         # leads to higher accuracy/run_time at fixed ef but consumes more
         # memory. No space_params are passed, so the nmslib defaults apply.
         # To tune them:
         #   nmslib.init(method='hnsw', space='cosinesimil',
         #               space_params={'ef': 100, 'M': 16})
         index = nmslib.init(method='hnsw', space='cosinesimil')
         if index_path != '' and os.path.isfile(index_path):
             index.loadIndex(index_path)
         else:
             index.addDataPointBatch(feats)
             index.createIndex({
                 'post': 2,
                 'indexThreadQty': 1
             },
                               print_progress=verbose)
             if index_path:
                 print('[hnsw] save index to {}'.format(index_path))
                 mkdir_if_no_exists(index_path)
                 index.saveIndex(index_path)
     with Timer('[hnsw] query topk {}'.format(k), verbose):
         knn_ofn = index_path + '.npz'
         # Without an index_path there is no cache location; checking the
         # bare '.npz' would pick up an unrelated file in the cwd.
         if index_path and os.path.exists(knn_ofn):
             print('[hnsw] read knns from {}'.format(knn_ofn))
             # Row 0 of each cached entry is cast to int32 ids, row 1 to
             # float32 scores.
             self.knns = [(knn[0, :].astype(np.int32),
                           knn[1, :].astype(np.float32))
                          for knn in np.load(knn_ofn)['data']]
         else:
             self.knns = index.knnQueryBatch(feats, k=k)
예제 #3
0
def build_knns(knn_prefix,
               feats,
               knn_method,
               k,
               num_process=None,
               is_rebuild=False):
    """Build the kNN graph of ``feats`` or load a cached copy.

    The result is cached at ``<knn_prefix>/<knn_method>_k_<k>.npz`` and the
    index is saved with the same stem and an ``.index`` suffix.

    Args:
        knn_prefix: Directory under which the cache files are stored.
        feats: Features to index.
        knn_method: One of 'hnsw', 'faiss' or 'faiss_gpu'.
        k: Number of neighbours.
        num_process: Thread/process count forwarded to the faiss backends.
        is_rebuild: Rebuild even if the cache file exists.

    Returns:
        The kNN data from ``index.get_knns()`` or ``load_data``.

    Raises:
        KeyError: If ``knn_method`` is not supported.
    """
    knn_prefix = os.path.join(knn_prefix, '{}_k_{}'.format(knn_method, k))
    mkdir_if_no_exists(knn_prefix)
    knn_path = knn_prefix + '.npz'
    if not os.path.isfile(knn_path) or is_rebuild:
        index_path = knn_prefix + '.index'
        with Timer('build index'):
            if knn_method == 'hnsw':
                index = knn_hnsw(feats, k, index_path)
            elif knn_method == 'faiss':
                index = knn_faiss(feats,
                                  k,
                                  index_path,
                                  omp_num_threads=num_process)
            elif knn_method == 'faiss_gpu':
                index = knn_faiss_gpu(feats,
                                      k,
                                      index_path,
                                      num_process=num_process)
            else:
                # Adjacent literals instead of a backslash continuation, which
                # copied the source indentation into the error message.
                raise KeyError('Unsupported method({}). Only support hnsw, '
                               'faiss and faiss_gpu currently'.format(
                                   knn_method))
            knns = index.get_knns()
        with Timer('dump knns to {}'.format(knn_path)):
            dump_data(knn_path, knns, force=True)
    else:
        print('read knn from {}'.format(knn_path))
        knns = load_data(knn_path)
    return knns
예제 #4
0
    def __init__(self, feats, k, index_path='', verbose=True, **kwargs):
        """Build (or load) an nmslib HNSW index on ``feats`` and query top-k.

        Args:
            feats: Data points added with ``addDataPointBatch`` and queried
                with ``knnQueryBatch``.
            k: Number of neighbours retrieved per point.
            index_path: Optional index file. An existing file is loaded;
                otherwise the freshly built index is saved there. Cached kNN
                results are read from ``index_path + '.npz'`` when present.
            verbose: Forwarded to ``Timer`` and ``createIndex``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        import nmslib
        self.verbose = verbose
        with Timer('[hnsw] build index', verbose):
            # No space_params are passed, so nmslib's default ef/M apply.
            # To tune them:
            #   nmslib.init(method='hnsw', space='cosinesimil',
            #               space_params={'ef': 100, 'M': 16})
            index = nmslib.init(method='hnsw', space='cosinesimil')
            if index_path != '' and os.path.isfile(index_path):
                index.loadIndex(index_path)
            else:
                index.addDataPointBatch(feats)
                index.createIndex({
                    'post': 2,
                    'indexThreadQty': 1
                },
                                  print_progress=verbose)
                if index_path:
                    print('[hnsw] save index to {}'.format(index_path))
                    mkdir_if_no_exists(index_path)
                    index.saveIndex(index_path)
        with Timer('[hnsw] query topk {}'.format(k), verbose):
            knn_ofn = index_path + '.npz'
            # Without an index_path there is no cache location; checking the
            # bare '.npz' would pick up an unrelated file in the cwd.
            if index_path and os.path.exists(knn_ofn):
                print('[hnsw] read knns from {}'.format(knn_ofn))
                self.knns = np.load(knn_ofn)['data']
            else:
                self.knns = index.knnQueryBatch(feats, k=k)
def get_output_path(args, ofn='pred_labels.txt'):
    """Return the output path of a clustering run and create its folder.

    The path is ``<args.oprefix>/<args.name>_<args.method>[_<params>]/<ofn>``
    where ``<params>`` encodes the hyper-parameters of known methods.

    Args:
        args: Namespace with ``name``, ``method``, ``oprefix`` and ``force``,
            plus the hyper-parameters used by ``args.method``.
        ofn: File name inside the run folder.

    Returns:
        The output file path.

    Raises:
        FileExistsError: If the path exists and ``args.force`` is falsy.
    """
    # Suffixes are built lazily so that ``args`` only needs the attributes of
    # the selected method; formatting every template eagerly raised
    # AttributeError whenever an unrelated method's option was missing.
    method2name = {
        'aro':
        lambda: 'k_{}_th_{}'.format(args.knn, args.th_sim),
        'knn_aro':
        lambda: 'k_{}_th_{}'.format(args.knn, args.th_sim),
        'dbscan':
        lambda: 'eps_{}_min_{}'.format(args.eps, args.min_samples),
        'knn_dbscan':
        lambda: 'eps_{}_min_{}_{}_k_{}_th_{}'.format(
            args.eps, args.min_samples, args.knn_method, args.knn,
            args.th_sim),
        'our_dbscan':
        lambda: 'min_{}_k_{}_th_{}'.format(args.min_samples, args.knn,
                                           args.th_sim),
        'hdbscan':
        lambda: 'min_{}'.format(args.min_samples),
        'fast_hierarchy':
        lambda: 'dist_{}_hmethod_{}'.format(args.distance, args.hmethod),
        'hierarchy':
        lambda: 'n_{}_k_{}'.format(args.n_clusters, args.knn),
        'knn_hierarchy':
        lambda: 'n_{}_k_{}_th_{}'.format(args.n_clusters, args.knn,
                                         args.th_sim),
        'mini_batch_kmeans':
        lambda: 'n_{}_bs_{}'.format(args.n_clusters, args.batch_size),
        'kmeans':
        lambda: 'n_{}'.format(args.n_clusters),
        'spectral':
        lambda: 'n_{}'.format(args.n_clusters),
        'dask_spectral':
        lambda: 'n_{}'.format(args.n_clusters),
        'knn_spectral':
        lambda: 'n_{}_k_{}_th_{}'.format(args.n_clusters, args.knn,
                                         args.th_sim),
        'densepeak':
        lambda: 'k_{}_th_{}_r_{}_m_{}'.format(args.knn, args.th_sim,
                                              args.radius, args.min_conn),
        'meanshift':
        lambda: 'bw_{}_bin_{}'.format(args.bw, args.min_bin_freq),
        'chinese_whispers':
        lambda: '{}_k_{}_th_{}_iters_{}'.format(args.knn_method, args.knn,
                                                args.th_sim, args.iters),
        'chinese_whispers_fast':
        lambda: '{}_k_{}_th_{}_iters_{}'.format(args.knn_method, args.knn,
                                                args.th_sim, args.iters),
    }

    if args.method in method2name:
        name = '{}_{}_{}'.format(args.name, args.method,
                                 method2name[args.method]())
    else:
        name = '{}_{}'.format(args.name, args.method)

    opath = os.path.join(args.oprefix, name, ofn)
    if os.path.exists(opath) and not args.force:
        raise FileExistsError(
            '{} has already existed. Please set force=True to overwrite.'.
            format(opath))
    mkdir_if_no_exists(opath)

    return opath
예제 #6
0
def main():
    """Entry point: build the config from CLI args and run the phase handler.

    Run-time options from the command line are copied into the config. Each
    of ``train_data``/``test_data`` that lacks an existing ``knn_graph_path``
    inherits ``prefix``, ``knn``, ``knn_method`` and its dataset name from
    the top-level config.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # cuda only when available and not disabled on the command line
    cfg.cuda = not args.no_cuda and torch.cuda.is_available()

    # optional cudnn switches requested by the config file
    for cfg_flag, cudnn_attr in (('cudnn_benchmark', 'benchmark'),
                                 ('cudnn_deterministic', 'deterministic')):
        if cfg.get(cfg_flag, False):
            setattr(torch.backends.cudnn, cudnn_attr, True)

    # a work_dir from the config wins, then the CLI, then one derived from
    # the config file name
    if not hasattr(cfg, 'work_dir'):
        cfg.work_dir = (args.work_dir if args.work_dir is not None else
                        os.path.join('./data/work_dir',
                                     rm_suffix(os.path.basename(args.config))))
    mkdir_if_no_exists(cfg.work_dir, is_folder=True)

    for option in ('load_from', 'resume_from', 'gpus', 'distributed',
                   'random_conns', 'eval_interim', 'save_output', 'force'):
        setattr(cfg, option, getattr(args, option))

    for split in ('train_data', 'test_data'):
        if not hasattr(cfg, split):
            continue
        split_cfg = cfg[split]
        split_cfg.eval_interim = cfg.eval_interim
        graph_ready = (hasattr(split_cfg, 'knn_graph_path')
                       and os.path.isfile(split_cfg.knn_graph_path))
        if graph_ready:
            continue
        for option in ('prefix', 'knn', 'knn_method'):
            setattr(split_cfg, option, getattr(cfg, option))
        split_cfg.name = cfg['train_name' if split == 'train_data' else
                             'test_name']

    logger = create_logger()

    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model = build_model(cfg.model['type'], **cfg.model['kwargs'])
    handler = build_handler(args.phase, cfg.model['type'])
    handler(model, cfg, logger)
예제 #7
0
 def __init__(self,
              feats,
              k,
              index_path='',
              index_key='',
              nprobe=128,
              omp_num_threads=None,
              rebuild_index=True,
              verbose=True,
              **kwargs):
     """Build (or load) a faiss inner-product index and query its top-k.

     Args:
         feats: Feature matrix; cast to float32 before indexing/searching.
         k: Number of neighbours per query.
         index_path: Optional index file to read/write. Cached kNN results
             are read from ``index_path + '.npz'`` when present.
         index_key: '' for exact search with ``GpuIndexFlatIP``, 'IVF' for an
             ``IndexIVFFlat``, or any other ``faiss.index_factory`` string
             (HNSW keys are rejected).
         nprobe: Upper bound for the index ``nprobe``; capped at ``nlist``.
         omp_num_threads: If given, passed to ``faiss.omp_set_num_threads``.
         rebuild_index: When False, an existing ``index_path`` is read.
         verbose: Forwarded to ``Timer``.
         **kwargs: Accepted for interface compatibility; unused.
     """
     import faiss
     if omp_num_threads is not None:
         faiss.omp_set_num_threads(omp_num_threads)
     self.verbose = verbose
     # faiss expects float32 input; the cast was previously skipped when the
     # index was read from disk, although feats are searched in both cases.
     feats = feats.astype('float32')
     with Timer('[my faiss gpu] build index', verbose):
         if index_path != '' and not rebuild_index and os.path.exists(
                 index_path):
             print('[my faiss gpu] read index from {}'.format(index_path))
             index = faiss.read_index(index_path)
         else:
             size, dim = feats.shape
             res = faiss.StandardGpuResources()
             index = faiss.GpuIndexFlatIP(res, dim)
             # The index_key branches below replace the GPU index by a CPU
             # one; index_gpu_to_cpu must only be applied to a GPU index.
             on_gpu = True
             if index_key != '':
                 assert index_key.find(
                     'HNSW') < 0, 'HNSW returns distances insted of sims'
                 metric = faiss.METRIC_INNER_PRODUCT
                 nlist = min(4096, 8 * round(math.sqrt(size)))
                 if index_key == 'IVF':
                     # NOTE(review): the GPU flat index is reused as the
                     # coarse quantizer of a CPU IndexIVFFlat; confirm this
                     # combination trains and serializes as intended.
                     quantizer = index
                     index = faiss.IndexIVFFlat(quantizer, dim, nlist,
                                                metric)
                 else:
                     index = faiss.index_factory(dim, index_key, metric)
                 on_gpu = False
                 if index_key.find('Flat') < 0:
                     assert not index.is_trained
                 index.train(feats)
                 index.nprobe = min(nprobe, nlist)
                 assert index.is_trained
                 print('nlist: {}, nprobe: {}'.format(nlist, nprobe))
             index.add(feats)
             if index_path != '':
                 print('[my faiss gpu] save index to {}'.format(index_path))
                 mkdir_if_no_exists(index_path)
                 index_cpu = faiss.index_gpu_to_cpu(index) if on_gpu else index
                 faiss.write_index(index_cpu, index_path)
     with Timer('[my faiss gpu] query topk {}'.format(k), verbose):
         knn_ofn = index_path + '.npz'
         # Without an index_path there is no cache location; checking the
         # bare '.npz' would pick up an unrelated file in the cwd.
         if index_path and os.path.exists(knn_ofn):
             print('[my faiss gpu] read knns from {}'.format(knn_ofn))
             self.knns = np.load(knn_ofn)['data']
         else:
             sims, nbrs = index.search(feats, k=k)
             # Similarities are stored as ``1 - sim``.
             self.knns = [(np.array(nbr, dtype=np.int32),
                           1 - np.array(sim, dtype=np.float32))
                          for nbr, sim in zip(nbrs, sims)]
예제 #8
0
파일: main.py 프로젝트: daip13/LPC_TRAIN
def main():
    """Entry point of the three-model pipeline.

    Builds the config from the CLI, constructs ``model1``..``model3`` from the
    config and, in the 'train' phase, initialises each model whose
    ``load_fromN`` option is set from that checkpoint. Finally it runs the
    handler selected by ``args.phase`` and ``args.stage``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cuda
    cfg.cuda = not args.no_cuda and torch.cuda.is_available()
    # set cudnn_benchmark & cudnn_deterministic
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if cfg.get('cudnn_deterministic', False):
        torch.backends.cudnn.deterministic = True
    # update configs according to args
    if not hasattr(cfg, 'work_dir'):
        if args.work_dir is not None:
            cfg.work_dir = args.work_dir
        else:
            cfg_name = rm_suffix(os.path.basename(args.config))
            cfg.work_dir = os.path.join('./data/work_dir', cfg_name)
    mkdir_if_no_exists(cfg.work_dir, is_folder=True)
    if not hasattr(cfg, 'stage'):
        cfg.stage = args.stage

    cfg.load_from1 = args.load_from1
    cfg.load_from2 = args.load_from2
    cfg.load_from3 = args.load_from3
    cfg.resume_from = args.resume_from

    #cfg.gpus = args.gpus
    cfg.distributed = args.distributed
    cfg.save_output = args.save_output
    cfg.phase = args.phase
    logger = create_logger()

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model = [
        build_model(cfg.model1['type'], **cfg.model1['kwargs']),
        build_model(cfg.model2['type'], **cfg.model2['kwargs']),
        build_model(cfg.model3['type'], **cfg.model3['kwargs']),
    ]
    if cfg.phase == 'train':
        # Each checkpoint is applied independently. model2/model3 used to be
        # bound only inside the load_from1 branch, so setting load_from2/3
        # without load_from1 raised UnboundLocalError.
        checkpoints = (cfg.load_from1, cfg.load_from2, cfg.load_from3)
        for net, ckpt in zip(model, checkpoints):
            if ckpt:
                # load_state_dict updates the model in place
                net.load_state_dict(torch.load(ckpt))
    handler = build_handler(args.phase, args.stage)

    handler(model, cfg, logger)
예제 #9
0
def build_knns(
        knn_prefix,
        feats,
        knn_method,
        k,
        num_process=16,  # default None
        is_rebuild=False,
        feat_create_time=None):
    """Return the kNN graph of ``feats``, rebuilding the cache when needed.

    The graph is cached at ``<knn_prefix>/<knn_method>_k_<k>.npz``. It is
    rebuilt when that file is missing, when ``is_rebuild`` is set, or when
    ``feat_create_time`` is given and the cache's mtime is not newer.

    Raises:
        KeyError: For a ``knn_method`` other than 'hnsw', 'faiss', 'faiss_gpu'.
    """
    stem = os.path.join(knn_prefix, '{}_k_{}'.format(knn_method, k))
    mkdir_if_no_exists(stem)
    knn_path = stem + '.npz'

    # A cache that is not newer than the features is treated as stale.
    may_reuse = os.path.isfile(knn_path) and not is_rebuild
    if may_reuse and feat_create_time is not None:
        cache_time = os.path.getmtime(knn_path)
        if cache_time <= feat_create_time:
            print('[warn] knn is created before feats ({} vs {})'.format(
                format_time(cache_time), format_time(feat_create_time)))
            is_rebuild = True

    if os.path.isfile(knn_path) and not is_rebuild:
        print('read knn from {}'.format(knn_path))
        return load_data(knn_path)

    index_path = stem + '.index'
    with Timer('build index'):
        if knn_method == 'hnsw':
            index = knn_hnsw(feats, k, index_path)
        elif knn_method == 'faiss':
            index = knn_faiss(feats, k, index_path,
                              omp_num_threads=num_process,
                              rebuild_index=True)
        elif knn_method == 'faiss_gpu':
            index = knn_faiss_gpu(feats, k, index_path,
                                  num_process=num_process)
        else:
            raise KeyError(
                'Only support hnsw and faiss currently ({}).'.format(
                    knn_method))
        knns = index.get_knns()
    with Timer('dump knns to {}'.format(knn_path)):
        dump_data(knn_path, knns, force=True)
    return knns
예제 #10
0
 def __init__(self,
              feats,
              k,
              index_path='',
              index_key='',
              nprobe=128,
              verbose=True):
     """Build (or load) a faiss inner-product index on CPU and query top-k.

     Args:
         feats: Feature matrix; cast to float32 before indexing/searching.
         k: Number of neighbours per query.
         index_path: Optional index file; read when it exists, otherwise
             written after building. Cached kNN results are read from
             ``index_path + '.npz'`` when present.
         index_key: '' for exact ``IndexFlatIP`` search, 'IVF' for an
             ``IndexIVFFlat``, or any other ``faiss.index_factory`` string
             (HNSW keys are rejected).
         nprobe: Upper bound for the index ``nprobe``; capped at ``nlist``.
         verbose: Forwarded to ``Timer``.
     """
     import faiss
     self.verbose = verbose
     # faiss expects float32 input; the cast was previously skipped when the
     # index was read from disk, although feats are searched in both cases.
     feats = feats.astype('float32')
     with Timer('[faiss] build index', verbose):
         if index_path != '' and os.path.exists(index_path):
             print('[faiss] read index from {}'.format(index_path))
             index = faiss.read_index(index_path)
         else:
             size, dim = feats.shape
             index = faiss.IndexFlatIP(dim)
             if index_key != '':
                 assert index_key.find(
                     'HNSW') < 0, 'HNSW returns distances insted of sims'
                 metric = faiss.METRIC_INNER_PRODUCT
                 nlist = min(4096, 8 * round(math.sqrt(size)))
                 if index_key == 'IVF':
                     # the flat IP index serves as the coarse quantizer
                     quantizer = index
                     index = faiss.IndexIVFFlat(quantizer, dim, nlist,
                                                metric)
                 else:
                     index = faiss.index_factory(dim, index_key, metric)
                 if index_key.find('Flat') < 0:
                     assert not index.is_trained
                 index.train(feats)
                 index.nprobe = min(nprobe, nlist)
                 assert index.is_trained
                 print('nlist: {}, nprobe: {}'.format(nlist, nprobe))
             index.add(feats)
             if index_path != '':
                 print('[faiss] save index to {}'.format(index_path))
                 mkdir_if_no_exists(index_path)
                 faiss.write_index(index, index_path)
     with Timer('[faiss] query topk {}'.format(k), verbose):
         knn_ofn = index_path + '.npz'
         # Without an index_path there is no cache location; checking the
         # bare '.npz' would pick up an unrelated file in the cwd.
         if index_path and os.path.exists(knn_ofn):
             print('[faiss] read knns from {}'.format(knn_ofn))
             # Row 0 of each cached entry is cast to int32 ids, row 1 to
             # float32 scores.
             self.knns = [(knn[0, :].astype(np.int32),
                           knn[1, :].astype(np.float32))
                          for knn in np.load(knn_ofn)['data']]
         else:
             sims, nbrs = index.search(feats, k=k)
             # Similarities are stored as ``1 - sim``.
             self.knns = [(np.array(nbr, dtype=np.int32),
                           1 - np.array(sim, dtype=np.float32))
                          for nbr, sim in zip(nbrs, sims)]
예제 #11
0
def main():
    """Entry point: load the config, merge CLI options, run the phase handler.

    ``stage`` and ``test_batch_size_per_gpu`` are filled in only when the
    config does not define them; the remaining run-time options are always
    copied from the command line.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # cuda only when available and not disabled on the command line
    cfg.cuda = not args.no_cuda and torch.cuda.is_available()

    # optional cudnn switches requested by the config file
    for cfg_flag, cudnn_attr in (('cudnn_benchmark', 'benchmark'),
                                 ('cudnn_deterministic', 'deterministic')):
        if cfg.get(cfg_flag, False):
            setattr(torch.backends.cudnn, cudnn_attr, True)

    # a work_dir from the config wins, then the CLI, then one derived from
    # the config file name
    if not hasattr(cfg, 'work_dir'):
        cfg.work_dir = (args.work_dir if args.work_dir is not None else
                        os.path.join('./data/work_dir',
                                     rm_suffix(os.path.basename(args.config))))
    mkdir_if_no_exists(cfg.work_dir, is_folder=True)

    if not hasattr(cfg, 'stage'):
        cfg.stage = args.stage
    if not hasattr(cfg, 'test_batch_size_per_gpu'):
        cfg.test_batch_size_per_gpu = cfg.batch_size_per_gpu

    for option in ('load_from', 'resume_from', 'pred_iou_score',
                   'pred_iop_score', 'gpus', 'det_label', 'distributed',
                   'save_output'):
        setattr(cfg, option, getattr(args, option))

    logger = create_logger()

    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model = build_model(cfg.model['type'], **cfg.model['kwargs'])
    handler = build_handler(args.phase, args.stage)
    handler(model, cfg, logger)
예제 #12
0
def get_output_path(args, ofn='pred_labels.txt'):
    """Return the output path of a clustering run and create its folder.

    The path is ``<args.oprefix>/<args.method>[_<params>]/<ofn>`` where
    ``<params>`` encodes the hyper-parameters of known methods.

    Args:
        args: Namespace with ``method``, ``oprefix`` and ``force``, plus the
            hyper-parameters used by ``args.method``.
        ofn: File name inside the run folder.

    Returns:
        The output file path.

    Raises:
        FileExistsError: If the path exists and ``args.force`` is falsy.
    """
    # Suffixes are built lazily so that ``args`` only needs the attributes of
    # the selected method; formatting every template eagerly raised
    # AttributeError whenever an unrelated method's option was missing.
    method2name = {
        'approx_rank_order':
        lambda: 'k_{}_th_{}'.format(args.knn, args.th_sim),
        'dbscan':
        lambda: 'eps_{}_min_{}'.format(args.eps, args.min_samples),
        'knn_dbscan':
        lambda: 'eps_{}_min_{}_k_{}_th_{}'.format(args.eps, args.min_samples,
                                                  args.knn, args.th_sim),
        'hdbscan':
        lambda: 'min_{}'.format(args.min_samples),
        'fast_hierarchy':
        lambda: 'dist_{}_hmethod_{}'.format(args.distance, args.hmethod),
        'hierarchy':
        lambda: 'n_{}_k_{}'.format(args.n_clusters, args.knn),
        'mini_batch_kmeans':
        lambda: 'n_{}_bs_{}'.format(args.n_clusters, args.batch_size),
        'kmeans':
        lambda: 'n_{}'.format(args.n_clusters),
        'spectral':
        lambda: 'n_{}'.format(args.n_clusters),
    }

    if args.method in method2name:
        name = '{}_{}'.format(args.method, method2name[args.method]())
    else:
        name = args.method

    opath = os.path.join(args.oprefix, name, ofn)
    if os.path.exists(opath) and not args.force:
        raise FileExistsError(
            '{} has already existed. Please set force=True to overwrite.'.
            format(opath))
    mkdir_if_no_exists(opath)

    return opath
예제 #13
0
def test_gcn_e(model, cfg, logger):
    """Run GCN-E on ``cfg.test_data``, cluster, optionally save, and evaluate.

    Predicted connections hold one row of ``cfg.max_conn`` relative neighbour
    positions per subset instance. They come from one of three sources:

    - loaded from ``<work_dir>/pred_conns.npz`` when it exists and
      ``cfg.force`` is falsy;
    - drawn at random when ``cfg.random_conns`` is set;
    - predicted with ``test``.

    The connections are appended to the dataset peaks and turned into
    labels with ``peaks_to_labels`` at threshold ``cfg.tau``.
    """
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    # These alias the dataset's own containers and are extended in place.
    pred_peaks = dataset.peaks
    pred_dist2peak = dataset.dist2peak

    ofn_pred = osp.join(cfg.work_dir, 'pred_conns.npz')
    if osp.isfile(ofn_pred) and not cfg.force:
        data = np.load(ofn_pred)
        pred_conns = data['pred_conns']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warn(
                'instance number in {} is different from dataset: {} vs {}'.
                format(ofn_pred, inst_num, dataset.inst_num))
    else:
        if cfg.random_conns:
            pred_conns = _random_conns(dataset.subset_nbrs, cfg.max_conn)
        else:
            pred_conns = test(model, dataset, cfg, logger)
        inst_num = dataset.inst_num

    # Attach the connections in every branch; cached connections used to be
    # loaded but never added to the peaks.
    _attach_pred_conns(dataset, pred_conns, pred_peaks, pred_dist2peak)

    if len(pred_conns) > 0:
        logger.info(
            'pred_conns (nbr order): mean({:.1f}), max({}), min({})'.format(
                pred_conns.mean(), pred_conns.max(), pred_conns.min()))

    if not dataset.ignore_label and cfg.eval_interim:
        subset_gt_labels = dataset.subset_gt_labels
        for i in range(cfg.max_conn):
            pred_peaks_labels = np.array([
                dataset.idx2lb[pred_peaks[idx][i]]
                for idx in dataset.subset_idxs
            ])

            acc = accuracy(pred_peaks_labels, subset_gt_labels)
            logger.info(
                '[{}-th] accuracy of pred_peaks labels ({}): {:.4f}'.format(
                    i, len(pred_peaks_labels), acc))

            # the rule for nearest nbr is only appropriate when nbrs is sorted
            nearest_idxs = np.where(pred_conns[:, i] == 0)[0]
            acc = accuracy(pred_peaks_labels[nearest_idxs],
                           subset_gt_labels[nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (nearest: {}): {:.4f}'.format(
                    i, len(nearest_idxs), acc))

            not_nearest_idxs = np.where(pred_conns[:, i] > 0)[0]
            acc = accuracy(pred_peaks_labels[not_nearest_idxs],
                           subset_gt_labels[not_nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (not nearest: {}): {:.4f}'.
                format(i, len(not_nearest_idxs), acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                      inst_num)

    if cfg.save_output:
        logger.info(
            'save predicted connectivity and labels to {}'.format(ofn_pred))
        if not osp.isfile(ofn_pred) or cfg.force:
            np.savez_compressed(ofn_pred,
                                pred_conns=pred_conns,
                                inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        folder = '{}_gcne_k_{}_th_{}_ig_{}'.format(cfg.test_name, cfg.knn,
                                                   cfg.th_sim,
                                                   cfg.test_data.ignore_ratio)
        opath_pred_labels = osp.join(cfg.work_dir, folder,
                                     'tau_{}_pred_labels.txt'.format(cfg.tau))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

        # H and C-scores (v-measure), keyed by the stringified instance index
        num_gt = len(dataset.gt_labels)
        gt_dict = {str(i): dataset.gt_labels[i] for i in range(num_gt)}
        pred_dict = {str(i): pred_labels[i] for i in range(num_gt)}
        bm = ClusteringBenchmark(gt_dict)
        scores = bm.evaluate_vmeasure(pred_dict)
        print(scores)


def _random_conns(subset_nbrs, max_conn):
    """Draw ``max_conn`` random relative neighbour positions per instance.

    Returns a 2-D array of shape ``(len(subset_nbrs), max_conn)``, matching
    the ``pred_conns[:, i]`` indexing used for interim evaluation.
    """
    rows = [[np.random.choice(np.arange(len(nbr))) for _ in range(max_conn)]
            for nbr in subset_nbrs]
    return np.asarray(rows, dtype=np.int64).reshape(len(rows), max_conn)


def _attach_pred_conns(dataset, pred_conns, pred_peaks, pred_dist2peak):
    """Append each subset instance's predicted neighbours to its peaks.

    Row ``j`` of ``pred_conns`` indexes into ``dataset.subset_nbrs[j]`` and
    ``dataset.subset_dists[j]``; the selected entries are appended to
    ``pred_peaks`` / ``pred_dist2peak`` at ``dataset.subset_idxs[j]``.
    """
    for rel_nbrs, nbr, dist, idx in zip(pred_conns, dataset.subset_nbrs,
                                        dataset.subset_dists,
                                        dataset.subset_idxs):
        pred_peaks[idx].extend(nbr[rel_nbrs])
        pred_dist2peak[idx].extend(dist[rel_nbrs])
예제 #14
0
def test_gcn_v(model, cfg, logger):
    """Run GCN-V on ``cfg.test_data``, cluster, optionally save, and evaluate.

    Predicted confidences are loaded from
    ``<work_dir>/<folder>/pred_confs/<oname>.npz`` when present and
    ``cfg.force`` is falsy; otherwise they are computed with ``test``.
    Peaks are derived with ``confidence_to_peaks`` and labels with
    ``peaks_to_labels`` at threshold ``cfg.tau_0``.

    When ``cfg.use_gcn_feat`` is set, the GCN features are written to disk,
    a new kNN graph is built on them, and the clustering is repeated at
    threshold ``cfg.tau``. An image export into per-cluster folders then runs
    only if both ``cfg.img_index_path`` and ``cfg.cluster_output_dir`` are
    set.
    """
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    folder = '{}_gcnv_k_{}_th_{}'.format(cfg.test_name, cfg.knn, cfg.th_sim)
    oprefix = osp.join(cfg.work_dir, folder)
    oname = osp.basename(rm_suffix(cfg.load_from))
    opath_pred_confs = osp.join(oprefix, 'pred_confs', '{}.npz'.format(oname))

    # gcn_feat is only produced by ``test``; it stays None when cached
    # confidences are reused and is computed lazily if it is needed later.
    gcn_feat = None
    if osp.isfile(opath_pred_confs) and not cfg.force:
        data = np.load(opath_pred_confs)
        pred_confs = data['pred_confs']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warn(
                'instance number in {} is different from dataset: {} vs {}'.
                format(opath_pred_confs, inst_num, dataset.inst_num))
    else:
        pred_confs, gcn_feat = test(model, dataset, cfg, logger)
        inst_num = dataset.inst_num

    logger.info('pred_confs: mean({:.4f}). max({:.4f}), min({:.4f})'.format(
        pred_confs.mean(), pred_confs.max(), pred_confs.min()))

    logger.info('Convert to cluster')
    with Timer('Predition to peaks'):
        pred_dist2peak, pred_peaks = confidence_to_peaks(
            dataset.dists, dataset.nbrs, pred_confs, cfg.max_conn)

    if not dataset.ignore_label and cfg.eval_interim:
        # evaluate the intermediate results
        for i in range(cfg.max_conn):
            num = len(dataset.peaks)
            # instances with fewer than i+1 peaks default to themselves
            pred_peaks_i = np.arange(num)
            peaks_i = np.arange(num)
            for j in range(num):
                if len(pred_peaks[j]) > i:
                    pred_peaks_i[j] = pred_peaks[j][i]
                if len(dataset.peaks[j]) > i:
                    peaks_i[j] = dataset.peaks[j][i]
            acc = accuracy(pred_peaks_i, peaks_i)
            logger.info('[{}-th conn] accuracy of peak match: {:.4f}'.format(
                i + 1, acc))
            acc = 0.
            for idx, peak in enumerate(pred_peaks_i):
                acc += int(dataset.idx2lb[peak] == dataset.idx2lb[idx])
            acc /= len(pred_peaks_i)
            logger.info(
                '[{}-th conn] accuracy of peak label match: {:.4f}'.format(
                    i + 1, acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau_0)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau_0,
                                      inst_num)

    if cfg.save_output:
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        opath_pred_labels = osp.join(
            cfg.work_dir, folder, 'tau_{}_pred_labels.txt'.format(cfg.tau_0))
        logger.info('save predicted labels to {}'.format(opath_pred_labels))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

    if cfg.use_gcn_feat:
        # gcn_feat is saved to disk for GCN-E
        opath_feat = osp.join(oprefix, 'features', '{}.bin'.format(oname))
        if not osp.isfile(opath_feat) or cfg.force:
            if gcn_feat is None:
                # Cached confidences were reused, so the features were never
                # computed; previously this raised NameError.
                _, gcn_feat = test(model, dataset, cfg, logger)
            mkdir_if_no_exists(opath_feat)
            write_feat(opath_feat, gcn_feat)

        name = rm_suffix(osp.basename(opath_feat))
        prefix = oprefix
        ds = BasicDataset(name=name,
                          prefix=prefix,
                          dim=cfg.model['kwargs']['nhid'],
                          normalize=True)
        ds.info()

        # use top embedding of GCN to rebuild the kNN graph
        with Timer('connect to higher confidence with use_gcn_feat'):
            knn_prefix = osp.join(prefix, 'knns', name)
            knns = build_knns(knn_prefix,
                              ds.features,
                              cfg.knn_method,
                              cfg.knn,
                              is_rebuild=True)
            dists, nbrs = knns2ordered_nbrs(knns)

            pred_dist2peak, pred_peaks = confidence_to_peaks(
                dists, nbrs, pred_confs, cfg.max_conn)
            pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                          inst_num)

        # save clustering results
        if cfg.save_output:
            oname_meta = '{}_gcn_feat'.format(name)
            opath_pred_labels = osp.join(
                oprefix, oname_meta, 'tau_{}_pred_labels.txt'.format(cfg.tau))
            mkdir_if_no_exists(opath_pred_labels)

            idx2lb = list2dict(pred_labels, ignore_value=-1)
            write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

        # evaluation
        if not dataset.ignore_label:
            print('==> evaluation')
            for metric in cfg.metrics:
                evaluate(dataset.gt_labels, pred_labels, metric)

        # Optional export of images grouped by predicted cluster. This used
        # to run unconditionally behind a pdb breakpoint with hard-coded
        # developer paths; it is now opt-in through the config.
        img_index_path = cfg.get('img_index_path', None)
        cluster_output_dir = cfg.get('cluster_output_dir', None)
        if img_index_path and cluster_output_dir:
            _export_clusters_to_folders(pred_labels, img_index_path,
                                        cluster_output_dir)


def _export_clusters_to_folders(pred_labels, img_index_path, output_dir):
    """Copy each indexed image into ``<output_dir>/cluter_<label>/``.

    ``img_index_path`` is read as JSON and iterated as a mapping from an
    image path to an index into ``pred_labels`` (presumably produced
    alongside the test features; verify against the producer).
    """
    import json
    import os
    import shutil
    with open(img_index_path, 'r', encoding='utf-8') as f:
        img_labels = json.load(f)
    # The 'cluter_' folder prefix is kept unchanged for existing consumers.
    for label in set(pred_labels):
        os.makedirs(os.path.join(output_dir, f'cluter_{label}'),
                    exist_ok=True)
    for image in img_labels:
        dst_dir = os.path.join(output_dir,
                               f'cluter_{pred_labels[img_labels[image]]}')
        shutil.copy2(image, os.path.join(dst_dir, os.path.split(image)[-1]))