コード例 #1
0
    def __init__(self,
                 scores,
                 new_golds=True,
                 thresholds=None,
                 gold_getter=None):
        """
        Build the export from a mapping of Score objects.

        Parameters
        ----------
        scores : {Score}
            Mapping of scores in export, keyed by id
        new_golds : bool
            When True, gold labels are refreshed through
            ``self._init_golds`` (using the gold getter); when False the
            gold labels already stored in the Score objects are kept
        thresholds : optional
            Precomputed thresholds. When None they are computed with
            ``self.find_thresholds(config.fpr, config.mdr)``
        gold_getter : GoldGetter, optional
            Gold label source. When None, a new GoldGetter is created and
            its ``all()`` method is called before it is stored
        """
        if gold_getter is None:
            gold_getter = GoldGetter()
            gold_getter.all()
        self.gold_getter = gold_getter

        # self.gold_getter is assigned first; _init_golds may depend on it
        if new_golds:
            scores = self._init_golds(scores)
        self.scores = scores

        def _probability(score_id):
            # sort key: the Score's p attribute (ascending)
            return scores[score_id].p

        self._sorted_ids = sorted(scores, key=_probability)
        self.class_counts = ScoreStats.counts(self.sorted_scores)

        self.thresholds = (self.find_thresholds(config.fpr, config.mdr)
                           if thresholds is None else thresholds)

        self._stats = None
コード例 #2
0
    def test_wrapper_getter(self):
        """After all(), the first entry of ``getters`` must be callable."""
        getter = GoldGetter()
        getter._golds = {}
        getter.all()

        registered = getter.getters
        print(registered)
        assert callable(registered[0])
コード例 #3
0
ファイル: history.py プロジェクト: cpadavis/swap
    def __init__(self, history, gold_getter=None):
        """
        Parameters
        ----------
        history : {History}
            Mapping of Subject History to subject id
        gold_getter : GoldGetter, optional
            Gold label source. When omitted, a new GoldGetter is built and
            its ``all()`` method is called before it is stored.
        """
        self.history = history

        if gold_getter is not None:
            self.gold_getter = gold_getter
            return

        fresh_getter = GoldGetter()
        fresh_getter.all()
        self.gold_getter = fresh_getter
コード例 #4
0
ファイル: control.py プロジェクト: drphilmarshall/swap
    def __init__(self, *args):
        """
            Initialize control

            Args:
                p0:              (Deprecated) prior subject probability
                epsilon:         (Deprecated) initial user score

            Raises:
                DeprecationWarning: if any positional argument is given;
                    p0 and epsilon are read from config instead.
        """
        if args:
            raise DeprecationWarning('p0 and epsilon now live in config')

        # Gold label source (presumably handed to SWAP later; confirm)
        self.gold_getter = GoldGetter()
        # No SWAP instance is created here
        self.swap = None
コード例 #5
0
    def _run(self):
        """
        Run one trial per gold-label subset produced by GoldIterator.

        A GoldGetter has ``all()`` called on it and its ``golds`` are passed
        to ``GoldIterator(golds, self.start, self.step)``. For each subset
        yielded, the labels are applied to SWAP, changes are processed and a
        ``randomex.Trial`` is recorded. The loop stops after the first trial
        whose subset holds more than ``self.end`` labels.
        """
        getter = GoldGetter()
        getter.all()

        swap = self.init_swap()
        subsets = GoldIterator(getter.golds, self.start, self.step)

        for trial_no, golds in enumerate(subsets):

            logger.info('Running trial %d with %d golds', trial_no, len(golds))
            # labels equal to -1 are counted as "fake" golds
            n_fake = sum(1 for label in golds.values() if label == -1)
            logger.debug('Fake n golds: %d', n_fake)

            swap.set_gold_labels(golds)
            swap.process_changes()
            trial = randomex.Trial(trial_no, golds, swap.score_export())
            self.add_trial(trial)

            if len(golds) > self.end:
                break
コード例 #6
0
    def _run(self):
        """
        Run one trial per (controversial, consensus) combination.

        For every ``cv`` in ``range(*self.controversial)`` and ``cn`` in
        ``range(*self.consensus)``, except the pair where both are 0, the
        gold getter is reset, ``cv`` controversial and ``cn`` consensus gold
        labels are requested (each only when positive), the labels are
        applied to SWAP, changes are processed and a ``self.Trial`` is
        recorded.
        """
        getter = GoldGetter()
        swap = self.init_swap()
        trial_no = 1
        for cv in range(*self.controversial):
            for cn in range(*self.consensus):
                if not (cv or cn):
                    # neither kind of gold label requested: nothing to run
                    continue

                getter.reset()
                logger.info('\nRunning trial %d with cv=%d cn=%d',
                            trial_no, cv, cn)

                if cv > 0:
                    getter.controversial(cv)
                if cn > 0:
                    getter.consensus(cn)

                swap.set_gold_labels(getter.golds)
                swap.process_changes()
                trial = self.Trial(cn, cv, getter.golds, swap.score_export())
                self.add_trial(trial)

                trial_no += 1
コード例 #7
0
    def _run(self):
        """
        Run ``self.num_trials`` random-gold trials for each gold count.

        For every ``n_golds`` in ``range(*self.num_golds)`` and each trial
        index, the gold getter is reset and asked for ``n_golds`` random gold
        labels. Those labels are applied to SWAP, changes are processed and a
        ``self.Trial`` holding the resulting score export is recorded.
        """
        gg = GoldGetter()
        swap = self.init_swap()
        for n_golds in range(*self.num_golds):
            for n in range(self.num_trials):

                gg.reset()
                gg.random(n_golds)

                golds = gg.golds
                logger.debug('Running trial %d with %d golds', n, n_golds)
                # Pass arguments lazily instead of pre-formatting with `%`,
                # so formatting only happens when debug logging is enabled.
                logger.debug('Real n golds: %d', len(golds))
                # labels equal to -1 are counted as "fake" golds
                fake = sum(1 for gold in golds.values() if gold == -1)
                logger.debug('Fake n golds: %d', fake)

                swap.set_gold_labels(gg.golds)
                swap.process_changes()
                self.add_trial(self.Trial(n, gg.golds, swap.score_export()))
コード例 #8
0
    def test_wrapper_golds_to_None(self):
        """After all(), ``_golds`` must go from ``{}`` back to None."""
        getter = GoldGetter()
        getter._golds = {}
        getter.all()

        assert getter._golds is None
コード例 #9
0
ファイル: swap.py プロジェクト: hughdickinson/hco-experiments
    def call(self, args):
        """
        Execute the actions requested by the parsed command line arguments.

        A SWAP instance and/or a score export are obtained from
        ``args.load``, ``args.scores_from_csv`` and ``args.run`` (checked in
        that order; later sources override the values set by earlier ones).
        Actions that need SWAP only run when a SWAP instance is available;
        actions that need scores only run when a score export is available.

        Parameters
        ----------
        args
            Parsed command line arguments; options taking values are read
            as lists (``args.load[0]`` etc.). Presumably an
            ``argparse.Namespace`` -- confirm against the parser.

        Returns
        -------
        SWAP or None
            The SWAP instance that was loaded or run, or None.
        """
        swap = None
        score_export = None

        if args.load:
            # The saved object may be a full SWAP instance or only a
            # ScoreExport.
            obj = self.load(args.load[0])

            if isinstance(obj, SWAP):
                swap = obj
                score_export = swap.score_export()
            elif isinstance(obj, ScoreExport):
                score_export = obj

        if args.scores_from_csv:
            fname = args.scores_from_csv[0]
            score_export = ScoreExport.from_csv(fname)

        if args.run:
            swap = self.run_swap(args)
            score_export = swap.score_export()

        if swap is not None:

            if args.save:
                manifest = self.manifest(swap, args)
                self.save(swap, self.f(args.save[0]), manifest)

            if args.subject:
                fname = self.f(args.subject[0])
                plots.traces.plot_subjects(swap.history_export(), fname)

            if args.user:
                fname = self.f(args.user[0])
                plots.plot_user_cm(swap, fname)

            if args.log:
                fname = self.f(args.log[0])
                write_log(swap, fname)

            if args.stats:
                s = swap.stats_str()
                print(s)
                logger.debug(s)

            if args.test:
                # Re-apply gold labels from a fresh GoldGetter and let SWAP
                # process the resulting changes.
                from swap.utils.golds import GoldGetter
                gg = GoldGetter()
                logger.debug('applying new gold labels')
                swap.set_gold_labels(gg.golds)
                swap.process_changes()
                logger.debug('done')

            if args.test_reorder:
                self.reorder_classifications(swap)

            if args.export_user_scores:
                fname = self.f(args.export_user_scores[0])
                self.export_user_scores(swap, fname)

        if score_export is not None:
            if args.save_scores:
                fname = self.f(args.save_scores[0])
                self.save(score_export, fname)

            if args.hist:
                fname = self.f(args.hist[0])
                plots.plot_class_histogram(fname, score_export)

            if args.dist:
                # plot_pdf needs per-subject scores from a SWAP instance.
                # When the scores came from a loaded ScoreExport or from
                # args.scores_from_csv, swap is still None and
                # `swap.subjects` would raise AttributeError; skip instead.
                if swap is None:
                    logger.warning(
                        'Cannot plot distribution without a SWAP instance; '
                        'skipping')
                else:
                    data = [s.getScore() for s in swap.subjects]
                    plots.plot_pdf(data,
                                   self.f(args.dist[0]),
                                   swap,
                                   cutoff=float(args.dist[1]))

            if args.presrec:
                fname = self.f(args.presrec[0])
                plots.distributions.sklearn_purity_completeness(
                    fname, score_export)

            if args.scores_to_csv:
                self.scores_to_csv(score_export, args.scores_to_csv[0])

        if args.diff:
            self.difference(args)

        if args.shell:
            # Interactive console with this method's locals in scope, so
            # local variable names here are visible to the shell user.
            import code
            code.interact(local=locals())

        return swap
コード例 #10
0
ファイル: control.py プロジェクト: drphilmarshall/swap
class Control:
    """
        Gets classifications from database and feeds them to SWAP
    """
    def __init__(self, *args):
        """
            Initialize control

            Args:
                p0:              (Deprecated) prior subject probability
                epsilon:         (Deprecated) initial user score

            Raises:
                DeprecationWarning: if any positional argument is given;
                    p0 and epsilon are now read from config.
        """
        if len(args) > 0:
            raise DeprecationWarning('p0 and epsilon now live in config')

        # Supplies the gold labels handed to SWAP in init_swap()
        self.gold_getter = GoldGetter()
        # Set by init_swap() or setSWAP(); None until then
        self.swap = None

    def run(self, amount=None):
        """
        Process all classifications in DB with SWAP

        Parameters
        ----------
        amount : int, optional
            Expected number of classifications. It is only used to size the
            progress bar and in the start log message. When None, it is
            read from the ``'first_classifications'`` entry of the DB
            classification stats.

        .. note::
            Iterates through the classification collection of the
            database and processes each classification one at a time
            in the order returned by the db. Progress is shown with a
            progress bar. When ``config.control.debug`` is set, processing
            stops once ``config.control.amount`` classifications have been
            processed. When ``config.back_update`` is set,
            ``swap.process_changes()`` runs after all classifications.
        """

        if amount is None:
            amount = DB().classifications.get_stats()
            amount = amount['first_classifications']

        self.init_swap()

        # get classifications
        cursor = self.get_classifications()

        # loop over classification cursor to process
        # classifications one at a time
        logger.info('Start: SWAP Processing %d classifications', amount)

        count = 0
        with progressbar.ProgressBar(max_value=amount) as bar:
            bar.update(count)
            # Loop over all classifications of the query
            # Note that the exact size of the query might be lower than
            # n_classifications if not all classifications are being queried
            for cl in cursor:
                # process classification in swap
                cl = Classification.generate(cl)
                self._delegate(cl)
                bar.update(count)
                count += 1

                # `count` is the number processed so far; `>=` stops after
                # exactly config.control.amount (`>` let one extra through)
                if config.control.debug and count >= config.control.amount:
                    break

        if config.back_update:
            logger.info('back_update active: processing changes')
            self.swap.process_changes()
        logger.info('done')

    def _delegate(self, cl):
        """
        Passes classification to SWAP

        Purpose is to allow subclasses to override how SWAP receives
        classifications

        Parameters
        ----------
        cl : Classification
            Classification being delegated
        """
        self.swap.classify(cl)

    def init_swap(self):
        """
        Return the SWAP instance to use, creating one if none is set yet,
        after passing it the current gold labels.

        Returns
        -------
        SWAP
            The (possibly newly created) SWAP instance, also stored in
            ``self.swap``
        """
        logger.debug('Initializing SWAP')
        if self.swap is None:
            swap = SWAP()
        else:
            swap = self.swap

        golds = self.get_gold_labels()
        swap.set_gold_labels(golds)

        self.swap = swap
        return swap

    def get_gold_labels(self):
        """
        Get the set of gold labels being used for this run
        """
        return self.gold_getter.golds

    @staticmethod
    def get_classifications():
        """
        Get the cursor containing classifications from db

        Returns
        -------
        swap.db.Cursor
            Cursor with classifications
        """
        return DB().classifications.getClassifications()

    def getSWAP(self):
        """
        Get the SWAP instance being used

        Returns
        -------
        SWAP
            SWAP
        """
        return self.swap

    def setSWAP(self, swap):
        """
        Set the SWAP object
        """
        self.swap = swap

    def reset(self):
        """
        Reset the gold getter and SWAP instances.

        Useful when running multiple subsequent instances of SWAP
        """
        self.swap = None
        self.gold_getter.reset()
コード例 #11
0
 def get_real_golds():
     """
     Fetch gold labels from database

     Returns
     -------
     The result of calling the object returned by ``GoldGetter().all()``.
     NOTE(review): this assumes ``all()`` returns a callable; elsewhere
     ``all()`` is called for its side effect and ``.golds`` is read
     instead -- confirm which API is current.
     """
     logger.debug('Getting real gold labels from db')
     fetch_golds = GoldGetter().all()
     return fetch_golds()
コード例 #12
0
    def call(self, args):
        """
        Execute the actions requested by the parsed command line arguments.

        A SWAP instance and/or a score export are obtained from
        ``args.load``, ``args.scores_from_csv`` and ``args.run`` (checked in
        that order; later sources override the values set by earlier ones).
        SWAP actions only run when a SWAP instance is available; score
        actions only run when a score export is available. Plotting is
        delegated to ``self.plot``, which receives both (either may be None).

        Parameters
        ----------
        args
            Parsed command line arguments; options taking values are read
            as lists (``args.load[0]`` etc.). Presumably an
            ``argparse.Namespace`` -- confirm against the parser.

        Returns
        -------
        SWAP or None
            The SWAP instance that was loaded or run, or None.
        """
        # NOTE: these local names are exposed to the interactive shell via
        # locals() at the end of this method.
        swap = None
        scores = None

        if args.load:
            # The saved object may be a full SWAP instance or only a
            # ScoreExport.
            obj = self.load(args.load[0])

            if isinstance(obj, SWAP):
                swap = obj
                scores = swap.score_export()
            elif isinstance(obj, ScoreExport):
                scores = obj

        if args.scores_from_csv:
            fname = args.scores_from_csv[0]
            scores = ScoreExport.from_csv(fname)

        if args.run:
            # A fresh run replaces anything loaded above.
            swap = self.run_swap(args)
            scores = swap.score_export()

        if swap is not None:

            if args.save:
                manifest = self.manifest(swap, args)
                self.save(swap, self.f(args.save[0]), manifest)

            if args.log:
                fname = self.f(args.log[0])
                write_log(swap, fname)

            if args.stats:
                s = swap.stats_str()
                print(s)
                logger.debug(s)

            if args.test:
                # Re-apply gold labels from a fresh GoldGetter, then run
                # swap.process_changes().
                from swap.utils.golds import GoldGetter
                gg = GoldGetter()
                logger.debug('applying new gold labels')
                swap.set_gold_labels(gg.golds)
                swap.process_changes()
                logger.debug('done')

            if args.test_reorder:
                self.reorder_classifications(swap)

            if args.export_user_scores:
                fname = self.f(args.export_user_scores[0])
                self.export_user_scores(swap, fname)

        if scores is not None:
            if args.save_scores:
                # Saved to the database via DB().subjects, not to a file.
                DB().subjects.save_scores(scores)

            if args.scores_to_csv:
                self.scores_to_csv(scores, args.scores_to_csv[0])

        # swap and/or scores may be None here.
        self.plot(args, swap, scores)

        if args.shell:
            # Interactive console with this method's locals in scope.
            import code
            code.interact(local=locals())

        return swap