Example #1
def build_ffm_hand(train_test):
    print("Building ffm starts.")
    col_idx_dict = {}
    for col in train_test.columns:
        vals = train_test[col].unique().tolist()
        col_idx_dict[col] = utils.list2dict(vals)

    processNum = 23
    cols = train_test.columns
    field_dict = utils.list2dict(cols)
    offset_dict = {}
    offset = 0
    for col in train_test.columns:
        offset_dict[col] = offset
        offset += len(col_idx_dict[col])

    split_size = len(cols) // processNum
    pool = multiprocessing.Pool(processes=processNum)
    for i in range(processNum):
        # the last worker takes whatever columns remain after the even split
        if i != processNum - 1:
            col_list = cols[i * split_size:(i + 1) * split_size]
        else:
            col_list = cols[i * split_size:]
        cid = {key: col_idx_dict[key] for key in col_list}
        fd = {key: field_dict[key] for key in col_list}
        od = {key: offset_dict[key] for key in col_list}
        pool.apply_async(onehot_col,
                         ([train_test[col_list], col_list, cid, fd, od], ))

    pool.close()
    pool.join()
    load_ffm(train_test)
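These snippets come from unrelated projects, and none of them shows the definition of the list2dict helper itself. For Examples #1 to #10 the surrounding code (building per-column index dictionaries above, inverting idx2class into class2idx in Example #3) suggests a value-to-index mapping; a minimal sketch under that assumption:

def list2dict(lst):
    # hypothetical sketch: map each distinct element to its list position
    return {val: idx for idx, val in enumerate(lst)}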
Example #2
 def run_pos(self, question, pos, posfile, n):
     """
     Uses POS tag information to find the answer to a given question
     """
     qtoks = self.preprocess(question)
     qdict = utils.list2dict(qtoks)
     return self.get_colloc_words_pos(qdict, pos, posfile, self.windowSize, n)
Example #3
 def load(self, src_file, readonly=True):
     with open(src_file, 'rb') as f:
         config = pickle.load(f)
     self._readonly = readonly
     self._option = config['option']
     self.idx2class = config['idx2class']
     self.class2idx = utils.list2dict(config['idx2class'])
     return self
Example #4
    def score_ents(self, ents, question, n):
        """
        Computes the confidence scores for each entity answer candidate to the given question.

        params
        ----
            ents: dict of key: entity answer string
                          value: tuple (string docid,
                                        string of words before & after the answer string)
            question: string of question to be answered
            n: number of answer candidates to be returned

        returns list of tuple (string of target entity token -- the answer candidate,
                               float of confidence score)
        """
        candidates = []
        qtoks = self.preprocess(question)
        qdict = utils.list2dict(qtoks)
        for (ent, (docid, colloc_words)) in ents.items():
            if ent not in question:
                wtoks = self.preprocess(colloc_words)
                score = self.score_from_words(wtoks, qdict)
                candidates.append((score, docid + " " + ent))
        candidates.sort(reverse=True)
        return candidates[:n]
Example #5
 def load(self, src_file, readonly=True):
     with open(src_file, 'rb') as f:
         config = pickle.load(f)
     self._readonly = readonly
     self._option = config['option']
     self.fidx2ngram = config['fidx2ngram']
     self.ngram2fidx = utils.list2dict(config['fidx2ngram'])
     self.feat_gen = self.parse_option(config['option'])
     return self
Example #6
 def load(self, src_file, readonly=True):
     with open(src_file, 'rb') as f:
         config = pickle.load(f)
     self._readonly = readonly
     self._option = config['option']
     self.idx2tok = config['idx2tok']
     self.tok2idx = utils.list2dict(config['idx2tok'])
     self.stemmer, self.stopword_remover = self.parse_option(config['option'])
     self.tokenizer = self.default_tokenizer
     return self
Example #7
def listaccounts(minconf=''):
    fname = 'btc/listaccounts'
    params = {'minconf': minconf}
    ret, res = get(fname, params)
    if not ret:
        print(res)
        return []
    if res.status_code != 200:
        return []
    res_dict = str2dict(res.text)
    if res_dict['success'] == 'false' or 'data' not in res_dict:
        return []
    return list(list2dict(res_dict['data']).keys())
Example #8
def getransactions(account='', count=10, skips=0):
    fname = 'btc/listtransactions'
    params = {
        'account': account,
        'count': count,
        'from': skips,
    }
    ret, res = post(fname, params)
    if not ret:
        print(res)
        return []
    if res.status_code != 200:
        return []
    res_dict = str2dict(res.text)
    if res_dict['success'] == 'false' or 'data' not in res_dict:
        return []
    result = list2dict(res_dict['data'])
    return result if isinstance(result, list) else [result]
Example #9
    def run_colloc(self, question, dfile, n):
        """
        params
        ----
            question: string of question
            dfile: the original answer data file in gzip format
            n: the number of top answer candidates to return

        returns the list of tuples (float of confidence score,
                                    string of docid + answer)
        """
        qdict = utils.list2dict(self.preprocess(question))
        top_n_cands = self.get_colloc_words(dfile, n, qdict)
        top_n_anses = self.shrink_answer_size(top_n_cands, qdict, self.answerSize)
        return top_n_anses
Example #10
 def __init__(self):
     """
     windowSize: the window size for collocation consideration
     answerSize: number of words allowed for an answer
     stopWords: a dictionary of stopwords
     stemmer: nltk Lancaster Stemmer. Example stemming:
          original string:
          Stemming is funnier than a bummer says the sushi loving computer scientist, Bob's wife.
          stemmed string:
          stem is funny than a bum say the sush lov comput sci , bob ' s wif .
     """
     self.windowSize = 10
     self.answerSize = 10
     self.stopWords = utils.list2dict(corpus.stopwords.words('english'))
     self.stemmer = stem.LancasterStemmer()
Example #11
def test_cluster_seg(model, cfg, logger):
    assert osp.isfile(cfg.pred_iou_score)

    if cfg.load_from:
        logger.info('load pretrained model from: {}'.format(cfg.load_from))
        load_checkpoint(model, cfg.load_from, strict=True, logger=logger)

    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)

    setattr(cfg.test_data, 'pred_iop_score', cfg.pred_iop_score)

    dataset = build_dataset(cfg.test_data)
    processor = build_processor(cfg.stage)

    inst_num = dataset.inst_num

    # read pred_scores from file and do sanity check
    d = np.load(cfg.pred_iou_score, allow_pickle=True)
    pred_scores = d['data']
    meta = d['meta'].item()
    assert inst_num == meta['tot_inst_num'], '{} vs {}'.format(
        inst_num, meta['tot_inst_num'])

    proposals = [fn_node for fn_node, _ in dataset.tot_lst]
    _proposals = []
    fn_node_pattern = '*_node.npz'
    for proposal_folder in meta['proposal_folders']:
        fn_clusters = sorted(
            glob.glob(osp.join(proposal_folder, fn_node_pattern)))
        _proposals.extend([fn_node for fn_node in fn_clusters])
    assert proposals == _proposals, '{} vs {}'.format(len(proposals),
                                                      len(_proposals))

    losses = []
    pred_outlier_scores = []
    stats = {'mean': []}

    if cfg.gpus == 1:
        data_loader = build_dataloader(dataset,
                                       processor,
                                       cfg.test_batch_size_per_gpu,
                                       cfg.workers_per_gpu,
                                       train=False)

        model = MMDataParallel(model, device_ids=range(cfg.gpus))
        if cfg.cuda:
            model.cuda()

        model.eval()
        for i, data in enumerate(data_loader):
            with torch.no_grad():
                output, loss = model(data, return_loss=True)
                losses += [loss.item()]
                if i % cfg.log_config.interval == 0:
                    if dataset.ignore_label:
                        logger.info('[Test] Iter {}/{}'.format(
                            i, len(data_loader)))
                    else:
                        logger.info('[Test] Iter {}/{}: Loss {:.4f}'.format(
                            i, len(data_loader), loss))
                if cfg.save_output:
                    output = F.softmax(output, dim=1)
                    output = output[:, 1, :]
                    scores = output.data.cpu().numpy()
                    pred_outlier_scores.extend(list(scores))
                    stats['mean'] += [scores.mean()]
    else:
        raise NotImplementedError

    if not dataset.ignore_label:
        avg_loss = sum(losses) / len(losses)
        logger.info('[Test] Overall Loss {:.4f}'.format(avg_loss))

    scores_mean = 1. * sum(stats['mean']) / len(stats['mean'])
    logger.info('mean of pred_outlier_scores: {:.4f}'.format(scores_mean))

    # save predicted scores
    if cfg.save_output:
        if cfg.load_from:
            fn = osp.basename(cfg.load_from)
        else:
            fn = 'random'
        opath = osp.join(cfg.work_dir, fn[:fn.rfind('.pth')] + '.npz')
        meta = {
            'tot_inst_num': inst_num,
            'proposal_folders': cfg.test_data.proposal_folders,
        }
        logger.info('dump pred_outlier_scores ({}) to {}'.format(
            len(pred_outlier_scores), opath))
        np.savez_compressed(opath, data=pred_outlier_scores, meta=meta)

    # post-process
    outlier_scores = {
        fn_node: outlier_score
        for (fn_node,
             _), outlier_score in zip(dataset.lst, pred_outlier_scores)
    }

    # de-overlap (w gcn-s)
    pred_labels_w_seg = deoverlap(pred_scores,
                                  proposals,
                                  inst_num,
                                  cfg.th_pos,
                                  cfg.th_iou,
                                  outlier_scores=outlier_scores,
                                  th_outlier=cfg.th_outlier,
                                  keep_outlier=cfg.keep_outlier)

    # de-overlap (wo gcn-s)
    pred_labels_wo_seg = deoverlap(pred_scores, proposals, inst_num,
                                   cfg.th_pos, cfg.th_iou)

    # save predicted labels
    if cfg.save_output:
        ofn_meta_w_seg = osp.join(cfg.work_dir, 'pred_labels_w_seg.txt')
        ofn_meta_wo_seg = osp.join(cfg.work_dir, 'pred_labels_wo_seg.txt')
        print('save predicted labels to {} and {}'.format(
            ofn_meta_w_seg, ofn_meta_wo_seg))
        pred_idx2lb_w_seg = list2dict(pred_labels_w_seg, ignore_value=-1)
        pred_idx2lb_wo_seg = list2dict(pred_labels_wo_seg, ignore_value=-1)
        write_meta(ofn_meta_w_seg, pred_idx2lb_w_seg, inst_num=inst_num)
        write_meta(ofn_meta_wo_seg, pred_idx2lb_wo_seg, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        gt_labels = dataset.labels
        print('==> evaluation (with gcn-s)')
        for metric in cfg.metrics:
            evaluate(gt_labels, pred_labels_w_seg, metric)
        print('==> evaluation (without gcn-s)')
        for metric in cfg.metrics:
            evaluate(gt_labels, pred_labels_wo_seg, metric)
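Examples #11 to #14 call a different list2dict(labels, ignore_value=-1), whose result is passed to write_meta under names like pred_idx2lb. A sketch consistent with that naming, assuming labels is a flat per-instance label list:

def list2dict(labels, ignore_value=-1):
    # hypothetical sketch: map instance index -> cluster label,
    # skipping unlabeled (ignore_value) instances
    return {idx: lb for idx, lb in enumerate(labels) if lb != ignore_value}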
Example #12
def test_gcn_e(model, cfg, logger):
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    pred_peaks = dataset.peaks
    pred_dist2peak = dataset.dist2peak

    ofn_pred = osp.join(cfg.work_dir, 'pred_conns.npz')
    if osp.isfile(ofn_pred) and not cfg.force:
        data = np.load(ofn_pred)
        pred_conns = data['pred_conns']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warning(
                'instance number in {} is different from dataset: {} vs {}'.
                format(ofn_pred, inst_num, len(dataset)))
    else:
        if cfg.random_conns:
            pred_conns = []
            for nbr, dist, idx in zip(dataset.subset_nbrs,
                                      dataset.subset_dists,
                                      dataset.subset_idxs):
                for _ in range(cfg.max_conn):
                    pred_rel_nbr = np.random.choice(np.arange(len(nbr)))
                    pred_abs_nbr = nbr[pred_rel_nbr]
                    pred_peaks[idx].append(pred_abs_nbr)
                    pred_dist2peak[idx].append(dist[pred_rel_nbr])
                    pred_conns.append(pred_rel_nbr)
            pred_conns = np.array(pred_conns)
        else:
            pred_conns = test(model, dataset, cfg, logger)
            for pred_rel_nbr, nbr, dist, idx in zip(pred_conns,
                                                    dataset.subset_nbrs,
                                                    dataset.subset_dists,
                                                    dataset.subset_idxs):
                pred_abs_nbr = nbr[pred_rel_nbr]
                pred_peaks[idx].extend(pred_abs_nbr)
                pred_dist2peak[idx].extend(dist[pred_rel_nbr])
        inst_num = dataset.inst_num

    if len(pred_conns) > 0:
        logger.info(
            'pred_conns (nbr order): mean({:.1f}), max({}), min({})'.format(
                pred_conns.mean(), pred_conns.max(), pred_conns.min()))

    if not dataset.ignore_label and cfg.eval_interim:
        subset_gt_labels = dataset.subset_gt_labels
        for i in range(cfg.max_conn):
            pred_peaks_labels = np.array([
                dataset.idx2lb[pred_peaks[idx][i]]
                for idx in dataset.subset_idxs
            ])

            acc = accuracy(pred_peaks_labels, subset_gt_labels)
            logger.info(
                '[{}-th] accuracy of pred_peaks labels ({}): {:.4f}'.format(
                    i, len(pred_peaks_labels), acc))

            # the rule for nearest nbr is only appropriate when nbrs is sorted
            nearest_idxs = np.where(pred_conns[:, i] == 0)[0]
            acc = accuracy(pred_peaks_labels[nearest_idxs],
                           subset_gt_labels[nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (nearest: {}): {:.4f}'.format(
                    i, len(nearest_idxs), acc))

            not_nearest_idxs = np.where(pred_conns[:, i] > 0)[0]
            acc = accuracy(pred_peaks_labels[not_nearest_idxs],
                           subset_gt_labels[not_nearest_idxs])
            logger.info(
                '[{}-th] accuracy of pred labels (not nearest: {}): {:.4f}'.
                format(i, len(not_nearest_idxs), acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                      inst_num)

    if cfg.save_output:
        logger.info(
            'save predicted connectivity and labels to {}'.format(ofn_pred))
        if not osp.isfile(ofn_pred) or cfg.force:
            np.savez_compressed(ofn_pred,
                                pred_conns=pred_conns,
                                inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        folder = '{}_gcne_k_{}_th_{}_ig_{}'.format(cfg.test_name, cfg.knn,
                                                   cfg.th_sim,
                                                   cfg.test_data.ignore_ratio)
        opath_pred_labels = osp.join(cfg.work_dir, folder,
                                     'tau_{}_pred_labels.txt'.format(cfg.tau))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

        # H and C-scores
        gt_dict = {}
        pred_dict = {}
        for i in range(len(dataset.gt_labels)):
            gt_dict[str(i)] = dataset.gt_labels[i]
            pred_dict[str(i)] = pred_labels[i]
        bm = ClusteringBenchmark(gt_dict)
        scores = bm.evaluate_vmeasure(pred_dict)
        # fmi_scores = bm.evaluate_fowlkes_mallows_score(pred_dict)
        print(scores)
Example #13
def test_cluster_det(model, cfg, logger):
    if cfg.load_from:
        logger.info('load pretrained model from: {}'.format(cfg.load_from))
        load_checkpoint(model, cfg.load_from, strict=True, logger=logger)

    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.test_data)
    processor = build_processor(cfg.stage)

    losses = []
    pred_scores = []

    if cfg.gpus == 1:
        data_loader = build_dataloader(dataset,
                                       processor,
                                       cfg.test_batch_size_per_gpu,
                                       cfg.workers_per_gpu,
                                       train=False)

        model = MMDataParallel(model, device_ids=range(cfg.gpus))
        if cfg.cuda:
            model.cuda()

        model.eval()
        for i, data in enumerate(data_loader):
            with torch.no_grad():
                output, loss = model(data, return_loss=True)
                losses += [loss.item()]
                if i % cfg.log_config.interval == 0:
                    if dataset.ignore_label:
                        logger.info('[Test] Iter {}/{}'.format(
                            i, len(data_loader)))
                    else:
                        logger.info('[Test] Iter {}/{}: Loss {:.4f}'.format(
                            i, len(data_loader), loss))
                if cfg.save_output:
                    output = output.view(-1)
                    prob = output.data.cpu().numpy()
                    pred_scores.append(prob)
    else:
        raise NotImplementedError

    if not dataset.ignore_label:
        avg_loss = sum(losses) / len(losses)
        logger.info('[Test] Overall Loss {:.4f}'.format(avg_loss))

    # save predicted scores
    if cfg.save_output:
        if cfg.load_from:
            fn = os.path.basename(cfg.load_from)
        else:
            fn = 'random'
        opath = os.path.join(cfg.work_dir, fn[:fn.rfind('.pth')] + '.npz')
        meta = {
            'tot_inst_num': dataset.inst_num,
            'proposal_folders': cfg.test_data.proposal_folders,
        }
        print('dump pred_score to {}'.format(opath))
        pred_scores = np.concatenate(pred_scores).ravel()
        np.savez_compressed(opath, data=pred_scores, meta=meta)

    # de-overlap
    proposals = [fn_node for fn_node, _ in dataset.lst]
    pred_labels = deoverlap(pred_scores, proposals, dataset.inst_num,
                            cfg.th_pos, cfg.th_iou)

    # save predicted labels
    if cfg.save_output:
        ofn_meta = os.path.join(cfg.work_dir, 'pred_labels.txt')
        print('save predicted labels to {}'.format(ofn_meta))
        pred_idx2lb = list2dict(pred_labels, ignore_value=-1)
        write_meta(ofn_meta, pred_idx2lb, inst_num=dataset.inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        gt_labels = dataset.labels
        for metric in cfg.metrics:
            evaluate(gt_labels, pred_labels, metric)
Example #14
def test_gcn_v(model, cfg, logger):
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    folder = '{}_gcnv_k_{}_th_{}'.format(cfg.test_name, cfg.knn, cfg.th_sim)
    oprefix = osp.join(cfg.work_dir, folder)
    oname = osp.basename(rm_suffix(cfg.load_from))
    opath_pred_confs = osp.join(oprefix, 'pred_confs', '{}.npz'.format(oname))

    if osp.isfile(opath_pred_confs) and not cfg.force:
        data = np.load(opath_pred_confs)
        pred_confs = data['pred_confs']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warning(
                'instance number in {} is different from dataset: {} vs {}'.
                format(opath_pred_confs, inst_num, len(dataset)))
    else:
        pred_confs, gcn_feat = test(model, dataset, cfg, logger)
        inst_num = dataset.inst_num

    logger.info('pred_confs: mean({:.4f}). max({:.4f}), min({:.4f})'.format(
        pred_confs.mean(), pred_confs.max(), pred_confs.min()))

    logger.info('Convert to cluster')
    with Timer('Prediction to peaks'):
        pred_dist2peak, pred_peaks = confidence_to_peaks(
            dataset.dists, dataset.nbrs, pred_confs, cfg.max_conn)

    if not dataset.ignore_label and cfg.eval_interim:
        # evaluate the intermediate results
        for i in range(cfg.max_conn):
            num = len(dataset.peaks)
            pred_peaks_i = np.arange(num)
            peaks_i = np.arange(num)
            for j in range(num):
                if len(pred_peaks[j]) > i:
                    pred_peaks_i[j] = pred_peaks[j][i]
                if len(dataset.peaks[j]) > i:
                    peaks_i[j] = dataset.peaks[j][i]
            acc = accuracy(pred_peaks_i, peaks_i)
            logger.info('[{}-th conn] accuracy of peak match: {:.4f}'.format(
                i + 1, acc))
            acc = 0.
            for idx, peak in enumerate(pred_peaks_i):
                acc += int(dataset.idx2lb[peak] == dataset.idx2lb[idx])
            acc /= len(pred_peaks_i)
            logger.info(
                '[{}-th conn] accuracy of peak label match: {:.4f}'.format(
                    i + 1, acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau_0)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau_0,
                                      inst_num)

    if cfg.save_output:
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        opath_pred_labels = osp.join(
            cfg.work_dir, folder, 'tau_{}_pred_labels.txt'.format(cfg.tau_0))
        logger.info('save predicted labels to {}'.format(opath_pred_labels))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

    if cfg.use_gcn_feat:
        # gcn_feat is saved to disk for GCN-E
        opath_feat = osp.join(oprefix, 'features', '{}.bin'.format(oname))
        if not osp.isfile(opath_feat) or cfg.force:
            mkdir_if_no_exists(opath_feat)
            write_feat(opath_feat, gcn_feat)

        name = rm_suffix(osp.basename(opath_feat))
        prefix = oprefix
        ds = BasicDataset(name=name,
                          prefix=prefix,
                          dim=cfg.model['kwargs']['nhid'],
                          normalize=True)
        ds.info()

        # use top embedding of GCN to rebuild the kNN graph
        with Timer('connect to higher confidence with use_gcn_feat'):
            knn_prefix = osp.join(prefix, 'knns', name)
            knns = build_knns(knn_prefix,
                              ds.features,
                              cfg.knn_method,
                              cfg.knn,
                              is_rebuild=True)
            dists, nbrs = knns2ordered_nbrs(knns)

            pred_dist2peak, pred_peaks = confidence_to_peaks(
                dists, nbrs, pred_confs, cfg.max_conn)
            pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                          inst_num)

        # save clustering results
        if cfg.save_output:
            oname_meta = '{}_gcn_feat'.format(name)
            opath_pred_labels = osp.join(
                oprefix, oname_meta, 'tau_{}_pred_labels.txt'.format(cfg.tau))
            mkdir_if_no_exists(opath_pred_labels)

            idx2lb = list2dict(pred_labels, ignore_value=-1)
            write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

        # evaluation

        if not dataset.ignore_label:
            print('==> evaluation')
            for metric in cfg.metrics:
                evaluate(dataset.gt_labels, pred_labels, metric)
        # experiment-specific epilogue: copy each image into a folder named
        # after its predicted cluster
        import json
        import os
        import shutil
        img_labels = json.load(
            open(r'/home/finn/research/data/clustering_data/test_index.json',
                 'r',
                 encoding='utf-8'))
        output = r'/home/finn/research/data/clustering_data/mr_gcn_output'
        for label in set(pred_labels):
            if not os.path.exists(os.path.join(output, f'cluster_{label}')):
                os.mkdir(os.path.join(output, f'cluster_{label}'))
        for image in img_labels:
            shutil.copy2(
                image,
                os.path.join(output,
                             f'cluster_{pred_labels[img_labels[image]]}',
                             os.path.split(image)[-1]))
Example #15
    def train(self, X_train, save_memory=False):
        """
        The training phase of the EUDTR model。

        Args:
            X_train: Train dataset
            save_memory: If True returns path to train data, otherwise to test data
        """

        train_torch_dataset = IndexedDatase(X_train, numpy.array(list(range(X_train.shape[0]))))
        train_generator = torch.utils.data.DataLoader(train_torch_dataset, batch_size=self.batch_size, shuffle=True)

        X_train = X_train.swapaxes(1,2)
        ks = KShape(n_clusters=self.n_clusters).fit(X_train)
        labels = ks.labels_
        X_train = X_train.swapaxes(1,2)

        sc_score = -1

        label2index = list2dict(labels)

        for epoch in range(self.epochs):

            epoch_start = time.time()
            self.encoder = self.encoder.train()
            self.decoder = self.decoder.train()

            for batch_num, batch in enumerate(train_generator):

                loss = 0
                indices, data = batch
                pos_samples, neg_samples = pos_neg_sampling(X_train, labels, label2index, indices, data, self.nb_random_samples)
                pos_samples = torch.from_numpy(pos_samples)
                neg_samples = torch.from_numpy(neg_samples).permute(1, 0, 2, 3)

                data = data.to(self.device)
                pos_samples = pos_samples.to(self.device)
                neg_samples = neg_samples.to(self.device)

                self.encoder_optimizer.zero_grad()
                self.decoder_optimizer.zero_grad()

                ref_embedding = self.encoder(data)
                pos_i_embedding = self.encoder(pos_samples)

                # Calculate the PN-Triplet loss and backward
                loss = -torch.mean(torch.nn.functional.logsigmoid(torch.bmm(
                        ref_embedding.view(data.shape[0], 1, self.out_channels),
                        pos_i_embedding.view(data.shape[0], self.out_channels, 1)
                        )))

                if save_memory:
                    loss.backward(retain_graph=True)
                    loss = 0
                    del pos_i_embedding
                    torch.cuda.empty_cache()

                multiplicative_ratio = self.negative_penalty / self.nb_random_samples

                for i in range(self.nb_random_samples):
                    neg_i_embedding = self.encoder(neg_samples[i])
                    loss += multiplicative_ratio * -torch.mean(torch.nn.functional.logsigmoid(-torch.bmm(
                        ref_embedding.view(data.shape[0], 1, self.out_channels),
                        neg_i_embedding.view(data.shape[0], self.out_channels, 1)
                        )))

                    if save_memory:
                        loss.backward(retain_graph=True)
                        loss = 0
                        del neg_i_embedding
                        torch.cuda.empty_cache()

                # Calculate the MI loss and backward
                self.mi_loss(data, neg_samples, self.encoder, self.decoder, save_memory)
        
                self.encoder_optimizer.step()
                self.decoder_optimizer.step()
         
            epoch_end = time.time()
            print('Train--Epoch: ', epoch + 1, " time: ", epoch_end - epoch_start)

            features = self.encode(X_train, self.batch_size)

            '''
            Silhouette-based early stopping: keep re-fitting KMeans and stop
            once the silhouette score has failed to improve 3 times in a row.
            '''
            consecutive_failures = 0
            while True:
                km = KMeans(n_clusters=self.n_clusters).fit(features.reshape(features.shape[0], -1))
                temp_score = silhouette_score(features.reshape(features.shape[0], -1), km.labels_)
                if temp_score > sc_score:
                    consecutive_failures = 0
                    sc_score = temp_score
                    print('sc_score changed:',sc_score)
                    labels = km.labels_
                    label2index = list2dict(labels)
                else:
                    consecutive_failures = consecutive_failures + 1
                    if consecutive_failures == 3:
                        break
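Example #15 uses the name for yet another shape: label2index = list2dict(labels) feeds pos_neg_sampling, which presumably needs every instance index sharing a given label. A plausible sketch under that assumption:

from collections import defaultdict

def list2dict(labels):
    # hypothetical sketch: group instance indices by cluster label
    label2index = defaultdict(list)
    for idx, lb in enumerate(labels):
        label2index[lb].append(idx)
    return dict(label2index)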