Code Example #1
File: extract.py Project: Jim61C/learn-to-cluster
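# Dependencies, as used in the learn-to-cluster project (helpers such as
# build_model, build_dataset, test, create_logger, mkdir_if_no_exists and
# write_feat come from the project's own modules; exact import paths may
# differ between versions):
import os.path as osp

import numpy as np
import torch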
def extract_gcn_v(opath_feat,
                  opath_pred_confs,
                  data_name,
                  cfg,
                  write_gcn_feat=False):
    if osp.isfile(opath_feat) and osp.isfile(opath_pred_confs):
        print('{} and {} already exist.'.format(opath_feat, opath_pred_confs))
        return
    cfg.cuda = torch.cuda.is_available()

    logger = create_logger()

    model = build_model(cfg.model['type'], **cfg.model['kwargs'])

    for k, v in cfg.model['kwargs'].items():
        setattr(cfg[data_name], k, v)
    cfg[data_name].eval_interim = False

    dataset = build_dataset(cfg.model['type'], cfg[data_name])

    pred_confs, gcn_feat = test(model, dataset, cfg, logger)
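    # test() performs a single inference pass over the whole graph, returning
    # per-vertex confidence scores plus the learned GCN embeddings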

    if not osp.exists(opath_pred_confs):
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=dataset.inst_num)

    if not osp.exists(opath_feat) and write_gcn_feat:
        logger.info('save gcn features to {}'.format(opath_feat))
        mkdir_if_no_exists(opath_feat)
        write_feat(opath_feat, gcn_feat)
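
A hypothetical call sketch (not from the original project; the config path and data key are assumptions, and learn-to-cluster reads its configs with mmcv):

from mmcv import Config

cfg = Config.fromfile('vegcn/configs/cfg_test_gcnv.py')  # hypothetical path
extract_gcn_v(opath_feat='work_dir/gcnv_feat.bin',
              opath_pred_confs='work_dir/pred_confs.npz',
              data_name='test_data',
              cfg=cfg,
              write_gcn_feat=True)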
Code Example #2
def train_gcn_v(model, cfg, logger):
    # prepare dataset
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.train_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.train_data)

    # train
    if cfg.distributed:
        raise NotImplementedError
    else:
        _single_train(model, dataset, cfg, logger)
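
A minimal invocation sketch, reusing the helpers shown in code example #1 (the config object is assumed to be loaded as above):

model = build_model(cfg.model['type'], **cfg.model['kwargs'])
train_gcn_v(model, cfg, create_logger())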
Code Example #3
def train_gcn_e(model, cfg, logger):
    # prepare data loaders
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.train_data, k, v)

    dataset = build_dataset(cfg.model['type'], cfg.train_data)
    data_loaders = [
        build_dataloader(dataset,
                         cfg.batch_size_per_gpu,
                         cfg.workers_per_gpu,
                         train=True,
                         shuffle=True)
    ]
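    # note: unlike train_gcn_v above, which passes the dataset itself to
    # _single_train, GCN-E trains from a shuffled dataloader of samples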

    # train
    if cfg.distributed:
        raise NotImplementedError
    else:
        _single_train(model, data_loaders, cfg, logger)
Code Example #4
def test_gcn_e(model, cfg, logger):
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    pred_peaks = dataset.peaks
    pred_dist2peak = dataset.dist2peak

    ofn_pred = osp.join(cfg.work_dir, 'pred_conns.npz')
    if osp.isfile(ofn_pred) and not cfg.force:
        data = np.load(ofn_pred)
        pred_conns = data['pred_conns']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warning(
                'instance number in {} is different from dataset: {} vs {}'.
                format(ofn_pred, inst_num, dataset.inst_num))
    else:
        if cfg.random_conns:
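            # random baseline: pick max_conn neighbors uniformly at random
            # instead of using the model's predicted connectivity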
            pred_conns = []
            for nbr, dist, idx in zip(dataset.subset_nbrs,
                                      dataset.subset_dists,
                                      dataset.subset_idxs):
                for _ in range(cfg.max_conn):
                    pred_rel_nbr = np.random.choice(np.arange(len(nbr)))
                    pred_abs_nbr = nbr[pred_rel_nbr]
                    pred_peaks[idx].append(pred_abs_nbr)
                    pred_dist2peak[idx].append(dist[pred_rel_nbr])
                    pred_conns.append(pred_rel_nbr)
            pred_conns = np.array(pred_conns)
        else:
            pred_conns = test(model, dataset, cfg, logger)
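            # map each predicted relative neighbor index to an absolute
            # vertex id and record it as a peak for that vertex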
            for pred_rel_nbr, nbr, dist, idx in zip(pred_conns,
                                                    dataset.subset_nbrs,
                                                    dataset.subset_dists,
                                                    dataset.subset_idxs):
                pred_abs_nbr = nbr[pred_rel_nbr]
                pred_peaks[idx].extend(pred_abs_nbr)
                pred_dist2peak[idx].extend(dist[pred_rel_nbr])
        inst_num = dataset.inst_num

    if len(pred_conns) > 0:
        logger.info(
            'pred_conns (nbr order): mean({:.1f}), max({}), min({})'.format(
                pred_conns.mean(), pred_conns.max(), pred_conns.min()))

    if not dataset.ignore_label and cfg.eval_interim:
        subset_gt_labels = dataset.subset_gt_labels
        for i in range(cfg.max_conn):
            pred_peaks_labels = np.array([
                dataset.idx2lb[pred_peaks[idx][i]]
                for idx in dataset.subset_idxs
            ])

            acc = accuracy(pred_peaks_labels, subset_gt_labels)
            logger.info(
                '[{}-th] accuracy of pred_peaks labels ({}): {:.4f}'.format(
                    i, len(pred_peaks_labels), acc))

            # the rule for nearest nbr is only appropriate when nbrs is sorted
            nearest_idxs = np.where(pred_conns[:, i] == 0)[0]
            acc = accuracy(pred_peaks_labels[nearest_idxs],
                           subset_gt_labels[nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (nearest: {}): {:.4f}'.format(
                    i, len(nearest_idxs), acc))

            not_nearest_idxs = np.where(pred_conns[:, i] > 0)[0]
            acc = accuracy(pred_peaks_labels[not_nearest_idxs],
                           subset_gt_labels[not_nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (not nearest: {}): {:.4f}'.
                format(i, len(not_nearest_idxs), acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau)):
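        # link each vertex to its predicted peaks (distance under cfg.tau)
        # and read the clusters off the resulting graph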
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                      inst_num)

    if cfg.save_output:
        logger.info(
            'save predicted connectivity and labels to {}'.format(ofn_pred))
        if not osp.isfile(ofn_pred) or cfg.force:
            np.savez_compressed(ofn_pred,
                                pred_conns=pred_conns,
                                inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        folder = '{}_gcne_k_{}_th_{}_ig_{}'.format(cfg.test_name, cfg.knn,
                                                   cfg.th_sim,
                                                   cfg.test_data.ignore_ratio)
        opath_pred_labels = osp.join(cfg.work_dir, folder,
                                     'tau_{}_pred_labels.txt'.format(cfg.tau))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

        # H and C-scores
        gt_dict = {}
        pred_dict = {}
        for i in range(len(dataset.gt_labels)):
            gt_dict[str(i)] = dataset.gt_labels[i]
            pred_dict[str(i)] = pred_labels[i]
        bm = ClusteringBenchmark(gt_dict)
        scores = bm.evaluate_vmeasure(pred_dict)
        # fmi_scores = bm.evaluate_fowlkes_mallows_score(pred_dict)
        print(scores)
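
For reference, the same homogeneity/completeness/V-measure numbers can be cross-checked with scikit-learn, which works directly on the flat label arrays (a sketch, not part of the original project):

from sklearn.metrics import homogeneity_completeness_v_measure

# dataset.gt_labels and pred_labels are parallel arrays of cluster ids
h, c, v = homogeneity_completeness_v_measure(dataset.gt_labels, pred_labels)
print('H={:.4f}, C={:.4f}, V={:.4f}'.format(h, c, v))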
Code Example #5
def test_gcn_v(model, cfg, logger):
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    folder = '{}_gcnv_k_{}_th_{}'.format(cfg.test_name, cfg.knn, cfg.th_sim)
    oprefix = osp.join(cfg.work_dir, folder)
    oname = osp.basename(rm_suffix(cfg.load_from))
    opath_pred_confs = osp.join(oprefix, 'pred_confs', '{}.npz'.format(oname))

    if osp.isfile(opath_pred_confs) and not cfg.force:
        data = np.load(opath_pred_confs)
        pred_confs = data['pred_confs']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warning(
                'instance number in {} is different from dataset: {} vs {}'.
                format(opath_pred_confs, inst_num, dataset.inst_num))
    else:
        pred_confs, gcn_feat = test(model, dataset, cfg, logger)
        inst_num = dataset.inst_num

    logger.info('pred_confs: mean({:.4f}), max({:.4f}), min({:.4f})'.format(
        pred_confs.mean(), pred_confs.max(), pred_confs.min()))

    logger.info('Convert to cluster')
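    # assumption based on the function name: confidence_to_peaks links each
    # vertex to up to cfg.max_conn neighbors of higher confidence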
    with Timer('Prediction to peaks'):
        pred_dist2peak, pred_peaks = confidence_to_peaks(
            dataset.dists, dataset.nbrs, pred_confs, cfg.max_conn)

    if not dataset.ignore_label and cfg.eval_interim:
        # evaluate the intermediate results
        for i in range(cfg.max_conn):
            num = len(dataset.peaks)
            pred_peaks_i = np.arange(num)
            peaks_i = np.arange(num)
            for j in range(num):
                if len(pred_peaks[j]) > i:
                    pred_peaks_i[j] = pred_peaks[j][i]
                if len(dataset.peaks[j]) > i:
                    peaks_i[j] = dataset.peaks[j][i]
            acc = accuracy(pred_peaks_i, peaks_i)
            logger.info('[{}-th conn] accuracy of peak match: {:.4f}'.format(
                i + 1, acc))
            acc = 0.
            for idx, peak in enumerate(pred_peaks_i):
                acc += int(dataset.idx2lb[peak] == dataset.idx2lb[idx])
            acc /= len(pred_peaks_i)
            logger.info(
                '[{}-th conn] accuracy of peak label match: {:.4f}'.format(
                    i + 1, acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau_0)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau_0,
                                      inst_num)

    if cfg.save_output:
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        opath_pred_labels = osp.join(
            cfg.work_dir, folder, 'tau_{}_pred_labels.txt'.format(cfg.tau_0))
        logger.info('save predicted labels to {}'.format(opath_pred_labels))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

    if cfg.use_gcn_feat:
        # gcn_feat is saved to disk for GCN-E
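        # note: gcn_feat is only bound when test() ran above; if pred_confs
        # were loaded from the cache, the feature file must already exist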
        opath_feat = osp.join(oprefix, 'features', '{}.bin'.format(oname))
        if not osp.isfile(opath_feat) or cfg.force:
            mkdir_if_no_exists(opath_feat)
            write_feat(opath_feat, gcn_feat)

        name = rm_suffix(osp.basename(opath_feat))
        prefix = oprefix
        ds = BasicDataset(name=name,
                          prefix=prefix,
                          dim=cfg.model['kwargs']['nhid'],
                          normalize=True)
        ds.info()

        # use top embedding of GCN to rebuild the kNN graph
        with Timer('connect to higher confidence with use_gcn_feat'):
            knn_prefix = osp.join(prefix, 'knns', name)
            knns = build_knns(knn_prefix,
                              ds.features,
                              cfg.knn_method,
                              cfg.knn,
                              is_rebuild=True)
            dists, nbrs = knns2ordered_nbrs(knns)

            pred_dist2peak, pred_peaks = confidence_to_peaks(
                dists, nbrs, pred_confs, cfg.max_conn)
            pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                          inst_num)

        # save clustering results
        if cfg.save_output:
            oname_meta = '{}_gcn_feat'.format(name)
            opath_pred_labels = osp.join(
                oprefix, oname_meta, 'tau_{}_pred_labels.txt'.format(cfg.tau))
            mkdir_if_no_exists(opath_pred_labels)

            idx2lb = list2dict(pred_labels, ignore_value=-1)
            write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

        # evaluation

        if not dataset.ignore_label:
            print('==> evaluation')
            for metric in cfg.metrics:
                evaluate(dataset.gt_labels, pred_labels, metric)
        # export the clustering result: copy every image into a folder named
        # after its predicted cluster (the paths below are machine-specific)
        import json
        import os
        import shutil

        with open(r'/home/finn/research/data/clustering_data/test_index.json',
                  'r',
                  encoding='utf-8') as f:
            # maps image path -> index into pred_labels
            img_labels = json.load(f)
        output = r'/home/finn/research/data/clustering_data/mr_gcn_output'
        for label in set(pred_labels):
            os.makedirs(os.path.join(output, f'cluster_{label}'),
                        exist_ok=True)
        for image in img_labels:
            label = pred_labels[img_labels[image]]
            shutil.copy2(
                image,
                os.path.join(output, f'cluster_{label}',
                             os.path.split(image)[-1]))