Code Example #1
File: tags.py Project: arizvisa/idascripts
    def globals(Globals, **tagmap):
        '''Apply the tags in `Globals` back into the database.'''
        global apply
        cls, tagmap_output = apply.__class__, u", {:s}".format(u', '.join(u"{:s}={:s}".format(internal.utils.string.escape(oldtag), internal.utils.string.escape(newtag)) for oldtag, newtag in six.iteritems(tagmap))) if tagmap else ''

        count = 0
        for ea, res in Globals:
            ns = func if func.within(ea) else db

            # grab the current (old) tag state
            state = ns.tag(ea)

            # transform the new tag state using the tagmap
            new = { tagmap.get(name, name) : value for name, value in six.viewitems(res) }

            # check if the tag mapping resulted in the deletion of a tag
            if len(new) != len(res):
                for name in six.viewkeys(res) - six.viewkeys(new):
                    logging.warn(u"{:s}.globals(...{:s}) : Refusing requested tag mapping as it results in the tag \"{:s}\" overwriting the tag \"{:s}\" in the global {:#x}. The value {!s} would be replaced with {!s}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.escape(name, '"'), internal.utils.string.escape(tagmap[name], '"'), ea, internal.utils.string.repr(res[name]), internal.utils.string.repr(res[tagmap[name]])))
                pass

            # check what's going to be overwritten with different values prior to doing it
            for name in six.viewkeys(state) & six.viewkeys(new):
                if state[name] == new[name]: continue
                logging.warn(u"{:s}.globals(...{:s}) : Overwriting tag \"{:s}\" for global at {:#x} with new value {!s}. Old value was {!s}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.escape(name, '"'), ea, internal.utils.string.repr(new[name]), internal.utils.string.repr(state[name])))

            # now we can apply the tags to the global address
            try:
                [ ns.tag(ea, name, value) for name, value in six.iteritems(new) if state.get(name, dummy) != value ]
            except:
                logging.warn(u"{:s}.globals(...{:s}) : Unable to apply tags ({!s}) to global {:#x}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.repr(new), ea), exc_info=True)

            # increase our counter
            count += 1
        return count
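
The pattern above is dict-view set algebra: six.viewkeys(a) - six.viewkeys(b) finds keys that would be dropped, and & finds keys whose values may be overwritten. A minimal standalone sketch with hypothetical tag dicts (on Python 3, six.viewkeys(d) is simply d.keys()):

import six

state = {'color': 'red', 'size': 3}    # old tag state
new = {'color': 'blue', 'weight': 7}   # transformed tag state

dropped = six.viewkeys(state) - six.viewkeys(new)      # keys that disappear
overwritten = six.viewkeys(state) & six.viewkeys(new)  # values may change

print(sorted(dropped))      # ['size']
print(sorted(overwritten))  # ['color']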
Code Example #2
File: tradesimulation.py Project: zhou/zipline
        def once_a_day(midnight_dt, current_data=self.current_data,
                       data_portal=self.data_portal):
            # process any capital changes that came overnight
            for capital_change in algo.calculate_capital_changes(
                    midnight_dt, emission_rate=emission_rate,
                    is_interday=True):
                yield capital_change

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            metrics_tracker.handle_market_open(
                midnight_dt,
                algo.data_portal,
            )

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = (
                viewkeys(metrics_tracker.positions) |
                viewkeys(algo.blotter.open_orders)
            )

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    metrics_tracker.handle_splits(splits)
Code Example #3
File: stat_ydensity.py Project: jwhendy/plotnine
    def setup_params(self, data):
        params = self.params.copy()

        valid_scale = ('area', 'count', 'width')
        if params['scale'] not in valid_scale:
            msg = "Parameter scale should be one of {}"
            raise PlotnineError(msg.format(valid_scale))

        lookup = {
            'biweight': 'biw',
            'cosine': 'cos',
            'cosine2': 'cos2',
            'epanechnikov': 'epa',
            'gaussian': 'gau',
            'triangular': 'tri',
            'triweight': 'triw',
            'uniform': 'uni'}

        with suppress(KeyError):
            params['kernel'] = lookup[params['kernel'].lower()]

        if params['kernel'] not in six.viewvalues(lookup):
            msg = ("kernel should be one of {}. "
                   "You may use the abbreviations {}")
            raise PlotnineError(msg.format(six.viewkeys(lookup),
                                           six.viewvalues(lookup)))

        missing_params = (six.viewkeys(stat_density.DEFAULT_PARAMS) -
                          six.viewkeys(params))
        for key in missing_params:
            params[key] = stat_density.DEFAULT_PARAMS[key]

        return params
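
The tail of setup_params uses a common idiom: fill any missing parameters from a defaults dict via key-view difference. A standalone sketch with hypothetical dicts:

import six

DEFAULT_PARAMS = {'scale': 'area', 'kernel': 'gau', 'trim': True}
params = {'kernel': 'biw'}

for key in six.viewkeys(DEFAULT_PARAMS) - six.viewkeys(params):
    params[key] = DEFAULT_PARAMS[key]

print(params)  # kernel stays 'biw'; scale and trim come from the defaults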
Code Example #4
File: periodic.py Project: viperf/magnum
    def sync_bay_status(self, ctx):
        try:
            LOG.debug('Starting to sync up bay status')
            osc = clients.OpenStackClients(ctx)
            status = [bay_status.CREATE_IN_PROGRESS,
                      bay_status.UPDATE_IN_PROGRESS,
                      bay_status.DELETE_IN_PROGRESS]
            filters = {'status': status}
            bays = objects.Bay.list(ctx, filters=filters)
            if not bays:
                return
            sid_to_bay_mapping = {bay.stack_id: bay for bay in bays}
            bay_stack_ids = sid_to_bay_mapping.keys()

            stacks = osc.heat().stacks.list(global_tenant=True,
                                            filters={'id': bay_stack_ids})
            sid_to_stack_mapping = {s.id: s for s in stacks}

            # intersection of bays magnum has and heat has
            for sid in (six.viewkeys(sid_to_bay_mapping) &
                        six.viewkeys(sid_to_stack_mapping)):
                stack = sid_to_stack_mapping[sid]
                bay = sid_to_bay_mapping[sid]
                self._sync_existing_bay(bay, stack)

            # the stacks that magnum has but heat doesn't have
            for sid in (six.viewkeys(sid_to_bay_mapping) -
                        six.viewkeys(sid_to_stack_mapping)):
                bay = sid_to_bay_mapping[sid]
                self._sync_missing_heat_stack(bay)

        except Exception as e:
            LOG.warn(_LW("Ignore error [%s] when syncing up bay status."), e,
                     exc_info=True)
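
The reconciliation above partitions stack ids with view operations: the intersection is what both magnum and heat know about, and the difference is what only magnum has. A toy sketch with made-up ids:

import six

ours = {'s1': 'bay-a', 's2': 'bay-b'}   # stack id -> bay
theirs = {'s1': 'stack-a'}              # stack id -> heat stack

both = six.viewkeys(ours) & six.viewkeys(theirs)       # sync these
only_ours = six.viewkeys(ours) - six.viewkeys(theirs)  # stack vanished

print(sorted(both), sorted(only_ours))  # ['s1'] ['s2']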
Code Example #5
def _dictionary_merge(dictionary_left, dictionary_right):
    """Merge two dictionaries preserving values for the same key.

    :param dictionary_left:
        A valid dictionary with keys and values.

        Example:
            dictionary_left = {1: 'A', 2: ['B', 'C'], 3: []}

    :param dictionary_right:
        A valid dictionary with keys and values.

        Example:
            dictionary_right = {1: 'A', 2: 'C', 4: 'E'}

    :return:
        A merged dictionary, which preserves both values in the situation
        of a key conflict.

        Example:
            {1: ['A', 'A'], 2: [['B', 'C'], 'C'], 3: [[]], 4: ['E']}
    """
    merged_dictionary = {}

    for key in (viewkeys(dictionary_left) | viewkeys(dictionary_right)):
        if key in dictionary_left:
            merged_dictionary.setdefault(key, []).append(dictionary_left[key])
        if key in dictionary_right:
            merged_dictionary.setdefault(key, []).append(
                dictionary_right[key])

    return merged_dictionary
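
A hypothetical call following the docstring examples, to make the merge behavior concrete (assumes viewkeys was imported from six):

left = {1: 'A', 2: ['B', 'C'], 3: []}
right = {1: 'A', 2: 'C', 4: 'E'}
print(_dictionary_merge(left, right))
# {1: ['A', 'A'], 2: [['B', 'C'], 'C'], 3: [[]], 4: ['E']}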
Code Example #6
File: tradesimulation.py Project: zhoukalex/catalyst
        def once_a_day(midnight_dt, current_data=self.current_data,
                       data_portal=self.data_portal):

            perf_tracker = algo.perf_tracker

            # Get the positions before updating the date so that prices are
            # fetched for trading close instead of midnight
            positions = algo.perf_tracker.position_tracker.positions
            position_assets = algo.asset_finder.retrieve_all(positions)

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # process any capital changes that came overnight
            for capital_change in algo.calculate_capital_changes(
                    midnight_dt, emission_rate=emission_rate,
                    is_interday=True):
                yield capital_change

            # we want to wait until the clock rolls over to the next day
            # before cleaning up expired assets.
            self._cleanup_expired_assets(midnight_dt, position_assets)

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = \
                viewkeys(perf_tracker.position_tracker.positions) | \
                viewkeys(algo.blotter.open_orders)

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    perf_tracker.position_tracker.handle_splits(splits)
Code Example #7
File: stat.py Project: jwhendy/plotnine
    def use_defaults(self, data):
        """
        Combine data with defaults and set aesthetics from parameters

        stats should not override this method.

        Parameters
        ----------
        data : dataframe
            Data used for drawing the geom.

        Returns
        -------
        out : dataframe
            Data used for drawing the geom.
        """
        missing = (self.aesthetics() -
                   six.viewkeys(self.aes_params) -
                   set(data.columns))

        for ae in missing-self.REQUIRED_AES:
            if self.DEFAULT_AES[ae] is not None:
                data[ae] = self.DEFAULT_AES[ae]

        missing = (six.viewkeys(self.aes_params) -
                   set(data.columns))

        for ae in self.aes_params:
            data[ae] = self.aes_params[ae]

        return data
Code Example #8
File: guide.py Project: jwhendy/plotnine
    def legend_aesthetics(self, layer, plot):
        """
        Return the aesthetics that contribute to the legend

        Parameters
        ----------
        layer : Layer
            Layer whose legend is to be drawn
        plot : ggplot
            Plot object

        Returns
        -------
        matched : list
            List of the names of the aethetics that contribute
            to the legend.
        """
        l = layer
        legend_ae = set(self.key.columns) - {'label'}
        all_ae = (six.viewkeys(l.mapping) |
                  (plot.mapping if l.inherit_aes else set()) |
                  six.viewkeys(l.stat.DEFAULT_AES))
        geom_ae = l.geom.REQUIRED_AES | six.viewkeys(l.geom.DEFAULT_AES)
        matched = all_ae & geom_ae & legend_ae
        matched = list(matched - set(l.geom.aes_params))
        return matched
Code Example #9
File: tradesimulation.py Project: UpSea/zipline
        def once_a_day(midnight_dt):
            # Get the positions before updating the date so that prices are
            # fetched for trading close instead of midnight
            positions = algo.perf_tracker.position_tracker.positions
            position_assets = algo.asset_finder.retrieve_all(positions)

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # we want to wait until the clock rolls over to the next day
            # before cleaning up expired assets.
            self._cleanup_expired_assets(midnight_dt, position_assets)

            perf_tracker = algo.perf_tracker

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = \
                viewkeys(perf_tracker.position_tracker.positions) | \
                viewkeys(algo.blotter.open_orders)

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    perf_tracker.position_tracker.handle_splits(splits)

            # call before trading start
            algo.before_trading_start(current_data)
Code Example #10
File: segmentation.py Project: I2Cvb/isic-archive
    def validate(self, doc):
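        # NOTE: this early return disables all of the validation checks
        # below, leaving them as dead code.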
        return doc
        try:
            assert set(six.viewkeys(doc)) == {
                '_id', 'imageId', 'skill', 'creatorId', 'lesionBoundary',
                'created'}

            assert isinstance(doc['imageId'], ObjectId)
            assert self.model('image', 'isic_archive').find(
                {'_id': doc['imageId']}).count()

            # TODO: better use of Enum
            assert doc['skill'] in {'novice', 'expert'}

            assert isinstance(doc['creatorId'], ObjectId)
            assert self.model('user').find(
                {'_id': doc['creatorId']}).count()

            assert isinstance(doc['lesionBoundary'], dict)
            assert set(six.viewkeys(doc['lesionBoundary'])) == {
                'type', 'properties', 'geometry'}

            assert doc['lesionBoundary']['type'] == 'Feature'

            assert isinstance(doc['lesionBoundary']['properties'], dict)
            assert set(six.viewkeys(doc['lesionBoundary']['properties'])) <= {
                'source', 'startTime', 'stopTime', 'seedPoint', 'tolerance'}
            assert set(six.viewkeys(doc['lesionBoundary']['properties'])) >= {
                'source', 'startTime', 'stopTime'}
            assert doc['lesionBoundary']['properties']['source'] in {
                'autofill', 'manual pointlist'}
            assert isinstance(doc['lesionBoundary']['properties']['startTime'],
                              datetime.datetime)
            assert isinstance(doc['lesionBoundary']['properties']['stopTime'],
                              datetime.datetime)

            assert isinstance(doc['lesionBoundary']['geometry'], dict)
            assert set(six.viewkeys(doc['lesionBoundary']['geometry'])) == {
                'type', 'coordinates'}
            assert doc['lesionBoundary']['geometry']['type'] == 'Polygon'
            assert isinstance(doc['lesionBoundary']['geometry']['coordinates'],
                              list)
            assert len(doc['lesionBoundary']['geometry']['coordinates']) == 1
            assert isinstance(
                doc['lesionBoundary']['geometry']['coordinates'][0], list)
            assert len(doc['lesionBoundary']['geometry']['coordinates'][0]) > 2
            assert doc['lesionBoundary']['geometry']['coordinates'][0][0] == \
                doc['lesionBoundary']['geometry']['coordinates'][0][-1]
            for coord in doc['lesionBoundary']['geometry']['coordinates'][0]:
                assert isinstance(coord, list)
                assert len(coord) == 2
                assert isinstance(coord[0], (int, float))
                assert isinstance(coord[1], (int, float))

            assert isinstance(doc['created'], datetime.datetime)

        except AssertionError:
            # TODO: message
            raise ValidationException('')
        return doc
Code Example #11
File: periodic.py Project: nttdata-osscloud/magnum
    def sync_bay_status(self, ctx):
        try:
            LOG.debug('Starting to sync up bay status')
            osc = clients.OpenStackClients(ctx)
            filters = [bay_status.CREATE_IN_PROGRESS,
                       bay_status.UPDATE_IN_PROGRESS,
                       bay_status.DELETE_IN_PROGRESS]
            bays = objects.Bay.list_all(ctx, filters=filters)
            if not bays:
                return
            sid_to_bay_mapping = {bay.stack_id: bay for bay in bays}
            bay_stack_ids = sid_to_bay_mapping.keys()

            stacks = osc.heat().stacks.list(global_tenant=True,
                                            filters={'id': bay_stack_ids})
            sid_to_stack_mapping = {s.id: s for s in stacks}

            for sid in (six.viewkeys(sid_to_bay_mapping) &
                        six.viewkeys(sid_to_stack_mapping)):
                stack = sid_to_stack_mapping[sid]
                bay = sid_to_bay_mapping[sid]
                if bay.status != stack.stack_status:
                    old_status = bay.status
                    bay.status = stack.stack_status
                    bay.save()
                    LOG.info(_LI("Sync up bay with id %(id)s from "
                                 "%(old_status)s to %(status)s."),
                             {'id': bay.id, 'old_status': old_status,
                              'status': bay.status})

            for sid in (six.viewkeys(sid_to_bay_mapping) -
                        six.viewkeys(sid_to_stack_mapping)):
                bay = sid_to_bay_mapping[sid]
                if bay.status == bay_status.DELETE_IN_PROGRESS:
                    bay.destroy()
                    LOG.info(_LI("Bay with id %(id)s has been deleted due "
                                 "to stack with id %(sid)s not found in "
                                 "Heat."),
                             {'id': bay.id, 'sid': sid})
                elif bay.status == bay_status.CREATE_IN_PROGRESS:
                    bay.status = bay_status.CREATE_FAILED
                    bay.save()
                    LOG.info(_LI("Bay with id %(id)s has been set to "
                                 "%(status)s due to stack with id %(sid)s "
                                 "not found in Heat."),
                             {'id': bay.id, 'status': bay.status,
                              'sid': sid})
                elif bay.status == bay_status.UPDATE_IN_PROGRESS:
                    bay.status = bay_status.UPDATE_FAILED
                    bay.save()
                    LOG.info(_LI("Bay with id %(id)s has been set to "
                                 "%(status)s due to stack with id %(sid)s "
                                 "not found in Heat."),
                             {'id': bay.id, 'status': bay.status,
                              'sid': sid})

        except Exception as e:
            LOG.warn(_LW("Ignore error [%s] when syncing up bay status."), e,
                     exc_info=True)
Code Example #12
File: __init__.py Project: EdDev/vdsm
    def _clean_running_config_from_removed_nets(self):
        # Cleanup running config from networks that have been actually
        # removed but not yet removed from running config.
        running_config = RunningConfig()
        nets2remove = (six.viewkeys(running_config.networks) -
                       six.viewkeys(self.runningConfig.networks))
        for net in nets2remove:
            running_config.removeNetwork(net)
        running_config.save()
Code Example #13
File: agents_db.py Project: igordcard/neutron
    def _adjust_az_filters(self, filters):
        # The intersect of sets gets us applicable filter keys (others ignored)
        common_keys = six.viewkeys(filters) & six.viewkeys(AZ_ATTRIBUTE_MAP)
        for key in common_keys:
            filter_key = AZ_ATTRIBUTE_MAP[key]['agent_key']
            filter_vals = filters.pop(key)
            if filter_vals:
                filter_vals = [AZ_ATTRIBUTE_MAP[key]['convert_to'](v)
                               for v in filter_vals]
            filters.setdefault(filter_key, [])
            filters[filter_key] += filter_vals
        return filters
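
A self-contained sketch of the same "intersect filters with known keys" idiom, using a hypothetical attribute map rather than neutron's real AZ_ATTRIBUTE_MAP:

from six import viewkeys

ATTRIBUTE_MAP = {'zone': {'agent_key': 'availability_zone',
                          'convert_to': str}}

def adjust_filters(filters):
    # The & of the views is computed up front, so popping keys while
    # looping is safe; keys absent from the map pass through untouched.
    for key in viewkeys(filters) & viewkeys(ATTRIBUTE_MAP):
        spec = ATTRIBUTE_MAP[key]
        vals = [spec['convert_to'](v) for v in filters.pop(key)]
        filters.setdefault(spec['agent_key'], []).extend(vals)
    return filters

print(adjust_filters({'zone': ['az1'], 'admin_state_up': [True]}))
# {'admin_state_up': [True], 'availability_zone': ['az1']}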
Code Example #14
File: predicates.py Project: quantopian/zipline
def assert_dict_equal(result, expected, path=(), msg="", **kwargs):
    _check_sets(viewkeys(result), viewkeys(expected), msg, path + (".%s()" % ("viewkeys" if PY2 else "keys"),), "key")

    failures = []
    for k, (resultv, expectedv) in iteritems(dzip_exact(result, expected)):
        try:
            assert_equal(resultv, expectedv, path=path + ("[%r]" % k,), msg=msg, **kwargs)
        except AssertionError as e:
            failures.append(str(e))

    if failures:
        raise AssertionError("\n".join(failures))
Code Example #15
File: config.py Project: arizvisa/syringe
def namespace(cls):
    # turn all instances of things into read-only attributes
    attrs,properties,subclass = {},{},{}
    for k,v in cls.__dict__.items():
        if hasattr(v, '__name__'):
            v.__name__ = '{}.{}'.format(cls.__name__,k)
        if k.startswith('_') or type(v) is property:
            attrs[k] = v
        elif not six.callable(v) or isinstance(v,type):
            properties[k] = v
        elif not hasattr(v, '__class__'):
            subclass[k] = namespace(v)
        else:
            attrs[k] = v
        continue

    def getprops(obj):
        result = []
        col1,col2 = 0,0
        for k,v in obj.items():
            col1 = max((col1,len(k)))
            if isinstance(v, type):
                val = '<>'
            elif hasattr(v, '__class__'):
                val = '{!r}'.format(v)
            else:
                raise ValueError(k)
            doc = v.__doc__.split('\n')[0] if v.__doc__ else None
            col2 = max((col2,len(val)))
            result.append((k, val, doc))
        return [('{name:{}} : {val:{}} # {doc}' if d else '{name:{}} : {val:{}}').format(col1,col2,name=k,val=v,doc=d) for k,v,d in result]

    def __repr__(self):
        props = getprops(properties)
        descr = ('{{{!s}}} # {}\n' if cls.__doc__ else '{{{!s}}}\n')
        subs = ['{{{}.{}}}\n...'.format(cls.__name__,k) for k in subclass.keys()]
        res = descr.format(cls.__name__,cls.__doc__) + '\n'.join(props)
        if subs:
            return res + '\n' + '\n'.join(subs) + '\n'
        return res + '\n'

    def __setattr__(self, name, value):
        if name in six.viewkeys(attrs):
            object.__setattr__(self, name, value)
            return
        raise AttributeError('Configuration \'{:s}\' does not have field named \'{:s}\''.format(cls.__name__,name))

    attrs['__repr__'] = __repr__
    attrs['__setattr__'] = __setattr__
    attrs.update((k,property(fget=lambda s,k=k:properties[k])) for k in six.viewkeys(properties))
    attrs.update((k,property(fget=lambda s,k=k:subclass[k])) for k in six.viewkeys(subclass))
    result = type(cls.__name__, cls.__bases__, attrs)
    return result()
Code Example #16
File: _bibfiles_db.py Project: clld/glottolog
def distance(left, right, weight={'author': 3, 'year': 3, 'title': 3, 'ENTRYTYPE': 2}):
    """Simple measure of the difference between two bibtex-field dicts."""
    if not (left or right):
        return 0.0

    keys = viewkeys(left) & viewkeys(right)
    if not keys:
        return 1.0

    weights = {k: weight.get(k, 1) for k in keys}
    ratios = (
        w * difflib.SequenceMatcher(None, left[k], right[k]).ratio()
        for k, w in weights.items())
    return 1 - (sum(ratios) / sum(weights.values()))
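
A quick check with hypothetical bibtex-field dicts (assumes difflib and viewkeys are imported as in the module): identical shared fields give 0.0, disjoint key sets give the maximum distance 1.0.

a = {'author': 'Smith, J.', 'year': '1990', 'title': 'On Things'}
b = dict(a, year='1991')

print(distance(a, a))              # 0.0
print(distance(a, {'note': 'x'}))  # 1.0, no shared keys
print(0.0 < distance(a, b) < 1.0)  # True, only the year differs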
Code Example #17
File: stat.py Project: jwhendy/plotnine
    def __init__(self, *args, **kwargs):
        kwargs = data_mapping_as_kwargs(args, kwargs)
        self._kwargs = kwargs  # Will be used to create the geom
        self.params = copy_keys(kwargs, deepcopy(self.DEFAULT_PARAMS))
        self.aes_params = {ae: kwargs[ae]
                           for ae in (self.aesthetics() &
                                      six.viewkeys(kwargs))}
Code Example #18
File: exp4p.py Project: ntucllab/striatum
    def _exp4p_score(self, context):
        """The main part of Exp4.P.
        """
        advisor_ids = list(six.viewkeys(context))

        w = self._modelstorage.get_model()['w']
        if len(w) == 0:
            for i in advisor_ids:
                w[i] = 1
        w_sum = sum(six.viewvalues(w))

        action_probs_list = []
        for action_id in self.action_ids:
            weighted_exp = [w[advisor_id] * context[advisor_id][action_id]
                            for advisor_id in advisor_ids]
            prob_vector = np.sum(weighted_exp) / w_sum
            action_probs_list.append((1 - self.n_actions * self.p_min)
                                     * prob_vector
                                     + self.p_min)
        action_probs_list = np.asarray(action_probs_list)
        action_probs_list /= action_probs_list.sum()

        estimated_reward = {}
        uncertainty = {}
        score = {}
        for action_id, action_prob in zip(self.action_ids, action_probs_list):
            estimated_reward[action_id] = action_prob
            uncertainty[action_id] = 0
            score[action_id] = action_prob
        self._modelstorage.save_model(
            {'action_probs': estimated_reward, 'w': w})

        return estimated_reward, uncertainty, score
Code Example #19
File: oauth_test.py Project: anukat2015/girder
    def getProviderResp():
        resp = self.request('/oauth/provider', params={
            'redirect': 'http://localhost/#foo/bar',
            'list': True
        })
        self.assertStatusOk(resp)
        self.assertIsInstance(resp.json, list)
        self.assertEqual(len(resp.json), 1)
        providerResp = resp.json[0]
        self.assertSetEqual(
            set(six.viewkeys(providerResp)),
            {'id', 'name', 'url'})
        self.assertEqual(providerResp['id'], providerInfo['id'])
        self.assertEqual(providerResp['name'], providerInfo['name'])
        self.assertRegexpMatches(
            providerResp['url'],
            providerInfo['url_re'])
        redirectParams = urllib.parse.parse_qs(
            urllib.parse.urlparse(providerResp['url']).query)
        csrfTokenParts = redirectParams['state'][0].partition('.')
        token = self.model('token').load(
            csrfTokenParts[0], force=True, objectId=False)
        self.assertLess(
            token['expires'],
            datetime.datetime.utcnow() + datetime.timedelta(days=0.30))
        self.assertEqual(
            csrfTokenParts[2],
            'http://localhost/#foo/bar')
        return providerResp
Code Example #20
File: linthompsamp.py Project: ntucllab/striatum
    def _linthompsamp_score(self, context):
        """Thompson Sampling"""
        action_ids = list(six.viewkeys(context))
        context_array = np.asarray([context[action_id]
                                    for action_id in action_ids])
        model = self._model_storage.get_model()
        B = model['B']  # pylint: disable=invalid-name
        mu_hat = model['mu_hat']
        v = self.R * np.sqrt(24 / self.epsilon
                             * self.context_dimension
                             * np.log(1 / self.delta))
        mu_tilde = self.random_state.multivariate_normal(
            mu_hat.flat, v**2 * np.linalg.inv(B))[..., np.newaxis]
        estimated_reward_array = context_array.dot(mu_hat)
        score_array = context_array.dot(mu_tilde)

        estimated_reward_dict = {}
        uncertainty_dict = {}
        score_dict = {}
        for action_id, estimated_reward, score in zip(
                action_ids, estimated_reward_array, score_array):
            estimated_reward_dict[action_id] = float(estimated_reward)
            score_dict[action_id] = float(score)
            uncertainty_dict[action_id] = float(score - estimated_reward)
        return estimated_reward_dict, uncertainty_dict, score_dict
Code Example #21
File: describe.py Project: jbeezley/girder
    def listResources(self, params):
        return {
            'apiVersion': API_VERSION,
            'swaggerVersion': SWAGGER_VERSION,
            'apis': [{'path': '/{}'.format(resource)}
                     for resource in sorted(six.viewkeys(docs.routes))]
        }
Code Example #22
File: us_equity_pricing.py Project: mannau/zipline
    def write_csvs(self,
                   asset_map,
                   show_progress=False,
                   invalid_data_behavior='warn'):
        """Read CSVs as DataFrames from our asset map.

        Parameters
        ----------
        asset_map : dict[int -> str]
            A mapping from asset id to file path with the CSV data for that
            asset
        show_progress : bool
            Whether or not to show a progress bar while writing.
        invalid_data_behavior : {'warn', 'raise', 'ignore'}
            What to do when data is encountered that is outside the range of
            a uint32.
        """
        read = partial(
            read_csv,
            parse_dates=['day'],
            index_col='day',
            dtype=self._csv_dtypes,
        )
        return self.write(
            ((asset, read(path)) for asset, path in iteritems(asset_map)),
            assets=viewkeys(asset_map),
            show_progress=show_progress,
            invalid_data_behavior=invalid_data_behavior,
        )
Code Example #23
File: momentum_pipeline.py Project: 4ever911/zipline
def rebalance(context, data):

    # Pipeline data will be a dataframe with boolean columns named 'longs' and
    # 'shorts'.
    pipeline_data = context.pipeline_data
    all_assets = pipeline_data.index

    longs = all_assets[pipeline_data.longs]
    shorts = all_assets[pipeline_data.shorts]

    record(universe_size=len(all_assets))

    # Build a 2x-leveraged, equal-weight, long-short portfolio.
    one_third = 1.0 / 3.0
    for asset in longs:
        order_target_percent(asset, one_third)

    for asset in shorts:
        order_target_percent(asset, -one_third)

    # Remove any assets that should no longer be in our portfolio.
    portfolio_assets = longs | shorts
    positions = context.portfolio.positions
    for asset in viewkeys(positions) - set(portfolio_assets):
        # This will fail if the asset was removed from our portfolio because it
        # was delisted.
        if data.can_trade(asset):
            order_target_percent(asset, 0)
Code Example #24
    def _decorator(f):
        args, varargs, varkw, defaults = argspec = getargspec(f)
        if defaults is None:
            defaults = ()
        no_defaults = (NO_DEFAULT,) * (len(args) - len(defaults))
        args_defaults = list(zip(args, no_defaults + defaults))
        if varargs:
            args_defaults.append((varargs, NO_DEFAULT))
        if varkw:
            args_defaults.append((varkw, NO_DEFAULT))

        argset = set(args) | {varargs, varkw} - {None}

        # Arguments can be declared as tuples in Python 2.
        if not all(isinstance(arg, str) for arg in args):
            raise TypeError(
                "Can't validate functions using tuple unpacking: %s" %
                (argspec,)
            )

        # Ensure that all processors map to valid names.
        bad_names = viewkeys(processors) - argset
        if bad_names:
            raise TypeError(
                "Got processors for unknown arguments: %s." % bad_names
            )

        return _build_preprocessed_function(
            f, processors, args_defaults, varargs, varkw,
        )
Code Example #25
File: assets.py Project: kczxl/zipline
def _convert_asset_timestamp_fields(dict):
    """
    Takes in a dict of Asset init args and converts dates to pd.Timestamps
    """
    for key in (_asset_timestamp_fields & viewkeys(dict)):
        value = pd.Timestamp(dict[key], tz='UTC')
        dict[key] = None if pd.isnull(value) else value
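
A runnable sketch of the same pattern; the field set here is a stand-in, not zipline's real _asset_timestamp_fields:

import pandas as pd
from six import viewkeys

_timestamp_fields = frozenset({'start_date', 'end_date'})

def convert_timestamps(d):
    # Only convert the timestamp fields that are actually present.
    for key in _timestamp_fields & viewkeys(d):
        value = pd.Timestamp(d[key], tz='UTC')
        d[key] = None if pd.isnull(value) else value
    return d

print(convert_timestamps({'start_date': '2020-01-02', 'sid': 1}))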
Code Example #26
File: preprocess.py Project: kczxl/zipline
    def _decorator(f):
        args, varargs, varkw, defaults = argspec = getargspec(f)
        if defaults is None:
            defaults = ()
        no_defaults = (NO_DEFAULT,) * (len(args) - len(defaults))
        args_defaults = zip(args, no_defaults + defaults)

        argset = set(args)

        # These assumptions simplify the implementation significantly.  If you
        # really want to validate a *args/**kwargs function, you'll have to
        # implement this here or do it yourself.
        if varargs:
            raise TypeError(
                "Can't validate functions that take *args: %s" % argspec
            )
        if varkw:
            raise TypeError(
                "Can't validate functions that take **kwargs: %s" % argspec
            )

        # Arguments can be declared as tuples in Python 2.
        if not all(isinstance(arg, str) for arg in args):
            raise TypeError(
                "Can't validate functions using tuple unpacking: %s" % argspec
            )

        # Ensure that all processors map to valid names.
        bad_names = viewkeys(processors) - argset
        if bad_names:
            raise TypeError(
                "Got processors for unknown arguments: %s." % bad_names
            )

        return _build_preprocessed_function(f, processors, args_defaults)
Code Example #27
    def extract(cls, data):
        """
        Extract the value for this parser's field.
        Field keys in data are matched case insensitively.
        A MetadataFieldNotFoundException is raised if none of the allowed fields are found.
        A MultipleFieldException is raised if more than one of the allowed fields are found.
        """
        availableFields = six.viewkeys(data)
        allowedFields = set(field.lower() for field in cls.allowedFields)

        foundFields = [field for field
                       in availableFields
                       if field.lower() in allowedFields]

        if not foundFields:
            raise MetadataFieldNotFoundException(fields=cls.allowedFields)
        if len(foundFields) > 1:
            raise MultipleFieldException(name=cls.name, fields=sorted(foundFields))

        field = foundFields.pop()
        value = data.pop(field)

        assert value is None or isinstance(value, six.string_types)

        return value
Code Example #28
File: gpdiagnostics.py Project: dkhikhlukha/pwkit
    def fill_from_dict (self, bybl):
        self.nsamps = len (bybl)

        seenants = set ()
        for a1, a2 in six.viewkeys (bybl):
            seenants.add (a1)
            seenants.add (a2)
        self.ants = np.array (sorted (seenants))
        self.nants = self.ants.size
        self.ant_to_antidx = dict ((num, idx) for idx, num in enumerate (self.ants))

        # NOTE: the np.int/np.complex aliases were removed in NumPy 1.24;
        # the builtins work the same here.
        self.ncontrib = np.empty ((self.nsamps,), dtype=int)
        self.vis = np.empty ((self.nsamps,), dtype=complex)
        self.blidxs = np.empty ((self.nsamps, 2), dtype=int)
        self.nperant = np.zeros ((self.nants,), dtype=int)

        for i, (bl, (data, flags)) in enumerate (six.viewitems (bybl)):
            ok = ~flags
            self.ncontrib[i] = ok.sum ()
            self.vis[i] = data[ok].mean ()
            i1 = self.ant_to_antidx[bl[0]]
            i2 = self.ant_to_antidx[bl[1]]
            self.blidxs[i] = i1, i2
            self.nperant[i1] += 1
            self.nperant[i2] += 1
Code Example #29
File: parser.py Project: bernii/querystring-parser
def _normalize(d):
    """
    The parse function above generates lists in dict form,
    i.e. {'abc': {0: 'xyz', 1: 'pqr'}}. This function normalizes that
    output and turns it into the proper data type,
    i.e. {'abc': ['xyz', 'pqr']}.

    Note: if the dict has elements starting at 10, 11, etc., this function
    won't fill in the blanks.
    E.g. {'abc': {10: 'xyz', 12: 'pqr'}} will convert to
    {'abc': ['xyz', 'pqr']}
    """
    newd = {}
    if not isinstance(d, dict):
        return d
    # if dictionary. iterate over each element and append to newd
    for k, v in six.iteritems(d):
        if isinstance(v, dict):
            first_key = next(iter(six.viewkeys(v)))
            if isinstance(first_key, int):
                temp_new = []
                for k1, v1 in v.items():
                    temp_new.append(_normalize(v1))
                newd[k] = temp_new
            elif first_key == "":
                newd[k] = list(v.values())[0]  # list() so this also works on Python 3
            else:
                newd[k] = _normalize(v)
        else:
            newd[k] = v
    return newd
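
A hypothetical input mirroring the docstring, showing integer-keyed sub-dicts turning into lists:

print(_normalize({'abc': {0: 'xyz', 1: 'pqr'}, 'plain': 'value'}))
# {'abc': ['xyz', 'pqr'], 'plain': 'value'}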
Code Example #30
File: scoring.py Project: girder/covalic
def computeAverageScores(score):
    """
    Compute the average score for each metric and add it to the score list
    under the name "Average".

    Datasets with a score of None are omitted from the average calculation.

    :param score: The score object to compute the average of. The result of the
        computation is placed at the head of the list.
    :type score: list
    """
    sums = defaultdict(float)
    counts = defaultdict(int)

    for dataset in score:
        for metric in dataset['metrics']:
            if metric['value'] is not None:
                sums[metric['name']] += float(metric['value'])
                counts[metric['name']] += 1

    metrics = [
        {
            'name': metricName,
            'value': sums[metricName] / float(counts[metricName])
        }
        for metricName in sorted(six.viewkeys(sums))]

    score.insert(0, {
        'dataset': 'Average',
        'metrics': metrics
    })
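
A small hypothetical score list to illustrate: the None value is skipped, and the "Average" row is inserted at the head:

score = [
    {'dataset': 'd1', 'metrics': [{'name': 'acc', 'value': 0.5}]},
    {'dataset': 'd2', 'metrics': [{'name': 'acc', 'value': None}]},
    {'dataset': 'd3', 'metrics': [{'name': 'acc', 'value': 0.7}]},
]
computeAverageScores(score)
print(score[0])
# {'dataset': 'Average', 'metrics': [{'name': 'acc', 'value': 0.6}]}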
Code Example #31
File: performance.py Project: arbennett/LIVVkit
def generate_timing_breakdown_plot(timing_stats, scaling_var, title, description, plot_file):
    """
    Description

    Args:
        timing_stats: a dictionary of the form
            {proc_count : {model||bench : { var : { stat : val }}}}
        scaling_var: the variable that accounts for the total runtime
        title: the title of the plot
        description: the description of the plot
        plot_file: the file to write the plot out to
    Returns:
        an image element containing the plot file and metadata
    """
    cmap_data = colormaps._viridis_data
    n_subplots = len(six.viewkeys(timing_stats))
    fig, ax = plt.subplots(1, n_subplots+1, figsize=(3*(n_subplots+2), 5))
    for plot_num, p_count in enumerate(
            sorted(six.iterkeys(timing_stats), key=functions.sort_processor_counts)):

        case_data = timing_stats[p_count]
        all_timers = set(six.iterkeys(case_data['model'])) | set(six.iterkeys(case_data['bench']))
        all_timers = sorted(list(all_timers), reverse=True)
        cmap_stride = int(len(cmap_data)/(len(all_timers)+1))
        colors = {all_timers[i]: cmap_data[i*cmap_stride] for i in range(len(all_timers))}

        sub_ax = plt.subplot(1, n_subplots+1, plot_num+1)
        sub_ax.set_title(p_count)
        sub_ax.set_ylabel('Runtime (s)')
        for case, var_data in case_data.items():
            if case == 'bench':
                bar_num = 2
            else:
                bar_num = 1

            offset = 0
            if var_data != {}:
                for var in sorted(six.iterkeys(var_data), reverse=True):
                    if var != scaling_var:
                        plt.bar(bar_num, var_data[var]['mean'], 0.8, bottom=offset,
                                color=colors[var], label=(var if bar_num == 1 else '_none'))
                        offset += var_data[var]['mean']

                plt.bar(bar_num, var_data[scaling_var]['mean']-offset, 0.8, bottom=offset,
                        color=colors[scaling_var], label=(scaling_var if bar_num == 1 else '_none'))

                sub_ax.set_xticks([1.4, 2.4])
                sub_ax.set_xticklabels(('test', 'bench'))

    plt.legend(loc=6, bbox_to_anchor=(1.05,0.5))
    plt.tight_layout()

    sub_ax = plt.subplot(1, n_subplots+1, n_subplots+1)
    hid_bar = plt.bar(1, 100)
    for group in hid_bar:
        group.set_visible(False)
    sub_ax.set_visible(False)

    if livvkit.publish:
        plt.savefig(os.path.splitext(plot_file)[0]+'.eps', dpi=600)
    plt.savefig(plot_file)
    plt.close()
    return elements.image(title, description, os.path.basename(plot_file))
Code Example #32
    def variables(self):
        """Return immutable view of variables in expression."""
        return viewkeys(self._variables)
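
Returning viewkeys gives callers a read-only but live view: later additions to the underlying dict show up, yet the view itself cannot be mutated. A standalone sketch:

from six import viewkeys

_variables = {'x': 1}
view = viewkeys(_variables)
_variables['y'] = 2
print(sorted(view))  # ['x', 'y'], the view tracks the dict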
Code Example #33
File: report.py Project: yimian/spark-df-profiling
def to_html(sample, stats_object):
    """
    Generate an HTML report from summary statistics and a given sample
    :param sample: DataFrame containing the sample you want to print
    :param stats_object: Dictionary containing summary statistics. Should be generated with an appropriate describe() function
    :return: profile report in HTML format
    :type: string
    """
    n_obs = stats_object['table']['n']
    row_formatters = formatters.row_formatters

    if not isinstance(sample, pd.DataFrame):
        raise TypeError('sample must be of type pandas.DataFrame')

    if not isinstance(stats_object, dict):
        raise TypeError(
            'stats_object must be of type dict. Did you generate this using the spark_df_profiling.describe() function?'
        )

    if set(stats_object.keys()) != {'table', 'variables', 'freq'}:
        raise TypeError(
            'stats_object badly formatted. Did you generate this using the spark_df_profiling-eda.describe() function?'
        )

    # Variables
    rows_html = ''
    messages = []

    for idx, row in stats_object['variables'].iterrows():
        formatted_values = {'varname': idx, 'varid': hash(idx)}
        for col, value in six.iteritems(row):
            formatted_values[col] = value_format(value, col)

        row_classes = {}
        for col in set(row.index) & six.viewkeys(row_formatters):
            row_classes[col] = row_formatters[col](row[col])
            if row_classes[col] == 'alert' and col in templates.messages:
                messages.append(templates.messages[col].format(
                    formatted_values, varname=formatters.fmt_varname(idx)))

        if row['type'] == 'CAT':
            formatted_values['minifreqtable'] = format_freq_table(
                idx, stats_object['freq'][idx], n_obs,
                stats_object['variables'].loc[idx],
                templates.template('mini_freq_table'),
                templates.template('mini_freq_table_row'), 3)
            formatted_values['freqtable'] = format_freq_table(
                idx, stats_object['freq'][idx], n_obs,
                stats_object['variables'].loc[idx],
                templates.template('freq_table'),
                templates.template('freq_table_row'), 20)
            if row['distinct_count'] > 50:
                messages.append(templates.messages['HIGH_CARDINALITY'].format(
                    formatted_values, varname=formatters.fmt_varname(idx)))
                row_classes['distinct_count'] = 'alert'
            else:
                row_classes['distinct_count'] = ''

        if row['type'] == 'UNIQUE':
            obs = stats_object['freq'][idx].index
            formatted_values['firstn'] = pd.DataFrame(
                obs[0:3],
                columns=['First 3 values']).to_html(classes='example_values',
                                                    index=False)
            formatted_values['lastn'] = pd.DataFrame(
                obs[-3:],
                columns=['Last 3 values']).to_html(classes='example_values',
                                                   index=False)

            if n_obs > 40:
                formatted_values['firstn_expanded'] = pd.DataFrame(
                    obs[0:20], index=range(1, 21)).to_html(
                        classes='sample table table-hover', header=False)
                formatted_values['lastn_expanded'] = pd.DataFrame(
                    obs[-20:], index=range(n_obs - 20 + 1, n_obs + 1)).to_html(
                        classes='sample table table-hover', header=False)
            else:
                formatted_values['firstn_expanded'] = pd.DataFrame(
                    obs, index=range(1, n_obs + 1)).to_html(
                        classes='sample table table-hover', header=False)
                formatted_values['lastn_expanded'] = ''

        rows_html += templates.row_templates_dict[row['type']].render(
            values=formatted_values, row_classes=row_classes)

        if row['type'] in ['CORR', 'CONST']:
            formatted_values['varname'] = formatters.fmt_varname(idx)
            messages.append(
                templates.messages[row['type']].format(formatted_values))

    # Overview
    formatted_values = {
        k: value_format(v, k)
        for k, v in six.iteritems(stats_object['table'])
    }

    row_classes = {}
    for col in six.viewkeys(
            stats_object['table']) & six.viewkeys(row_formatters):
        row_classes[col] = row_formatters[col](stats_object['table'][col])
        if row_classes[col] == 'alert' and col in templates.messages:
            messages.append(templates.messages[col].format(
                formatted_values, varname=formatters.fmt_varname(idx)))

    messages_html = ''
    for msg in messages:
        messages_html += templates.message_row.format(message=msg)

    overview_html = templates.template('overview').render(
        values=formatted_values,
        row_classes=row_classes,
        messages=messages_html)

    # Add Sample
    sample_html = templates.template('sample').render(
        sample_table_html=sample.to_html(classes='sample', index=False))
    # TODO: should be done in the template
    return templates.template('base').render({
        'overview_html': overview_html,
        'rows_html': rows_html,
        'sample_html': sample_html
    })
Code Example #34
def visualize_boxes_and_labels_on_image_array(
    image,
    boxes,
    classes,
    scores,
    category_index,
    instance_masks=None,
    instance_boundaries=None,
    keypoints=None,
    track_ids=None,
    use_normalized_coordinates=False,
    max_boxes_to_draw=20,
    min_score_thresh=0.5,
    agnostic_mode=False,
    line_thickness=4,
    groundtruth_box_visualization_color="black",
    skip_scores=False,
    skip_labels=False,
    skip_track_ids=False,
    detection_boundary_mask=None,
    detection_box_ignored=None,
):
    """Overlay labeled boxes on an image with formatted scores and label names.

    This function groups boxes that correspond to the same location
    and creates a display string for each detection and overlays these
    on the image. Note that this function modifies the image in place, and returns
    that same image.

    Args:
      image: uint8 numpy array with shape (img_height, img_width, 3)
      boxes: a numpy array of shape [N, 4]
      classes: a numpy array of shape [N]. Note that class indices are 1-based,
        and match the keys in the label map.
      scores: a numpy array of shape [N] or None.  If scores=None, then
        this function assumes that the boxes to be plotted are groundtruth
        boxes and plot all boxes as black with no classes or scores.
      category_index: a dict containing category dictionaries (each holding
        category index `id` and category name `name`) keyed by category indices.
      instance_masks: a numpy array of shape [N, image_height, image_width] with
        values ranging between 0 and 1, can be None.
      instance_boundaries: a numpy array of shape [N, image_height, image_width]
        with values ranging between 0 and 1, can be None.
      keypoints: a numpy array of shape [N, num_keypoints, 2], can
        be None
      track_ids: a numpy array of shape [N] with unique track ids. If provided,
        color-coding of boxes will be determined by these ids, and not the class
        indices.
      use_normalized_coordinates: whether boxes is to be interpreted as
        normalized coordinates or not.
      max_boxes_to_draw: maximum number of boxes to visualize.  If None, draw
        all boxes.
      min_score_thresh: minimum score threshold for a box to be visualized
      agnostic_mode: boolean (default: False) controlling whether to evaluate in
        class-agnostic mode or not.  This mode will display scores but ignore
        classes.
      line_thickness: integer (default: 4) controlling line width of the boxes.
      groundtruth_box_visualization_color: box color for visualizing groundtruth
        boxes
      skip_scores: whether to skip score when drawing a single detection
      skip_labels: whether to skip label when drawing a single detection
      skip_track_ids: whether to skip track id when drawing a single detection
      detection_boundary_mask: optional mask; the region outside it is
        darkened on the image before boxes are drawn.
      detection_box_ignored: optional per-box flags passed through to the
        box drawing routine.
    Returns:
      uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
    """
    # Create a display string (and color) for every box location, group any boxes
    # that correspond to the same location.
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_masks_map = {}
    box_to_instance_boundaries_map = {}
    box_to_keypoints_map = collections.defaultdict(list)
    box_to_track_ids_map = {}
    box_to_detection_box_ignored_map = {}

    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]
    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        if scores is None or scores[i] > min_score_thresh:
            box = tuple(boxes[i].tolist())
            if instance_masks is not None:
                box_to_instance_masks_map[box] = instance_masks[i]
            if instance_boundaries is not None:
                box_to_instance_boundaries_map[box] = instance_boundaries[i]
            if keypoints is not None:
                box_to_keypoints_map[box].extend(keypoints[i])
            if track_ids is not None:
                box_to_track_ids_map[box] = track_ids[i]
            if detection_box_ignored is not None:
                box_to_detection_box_ignored_map[box] = detection_box_ignored[
                    i]
            if scores is None:
                box_to_color_map[box] = groundtruth_box_visualization_color
            else:
                display_str = ""
                if not skip_labels:
                    if not agnostic_mode:
                        if classes[i] in six.viewkeys(category_index):
                            class_name = category_index[classes[i]]["name"]
                        else:
                            class_name = "N/A"
                        display_str = str(class_name)
                if not skip_scores:
                    if not display_str:
                        display_str = "{}%".format(int(100 * scores[i]))
                    else:
                        display_str = "{}: {}%".format(display_str,
                                                       int(100 * scores[i]))
                if not skip_track_ids and track_ids is not None:
                    if not display_str:
                        display_str = "ID {}".format(track_ids[i])
                    else:
                        display_str = "{}: ID {}".format(
                            display_str, track_ids[i])
                box_to_display_str_map[box].append(display_str)
                if agnostic_mode:
                    box_to_color_map[box] = "DarkOrange"
                elif track_ids is not None:
                    prime_multipler = _get_multiplier_for_color_randomness()
                    box_to_color_map[box] = STANDARD_COLORS[
                        (prime_multipler * track_ids[i]) %
                        len(STANDARD_COLORS)]
                else:
                    box_to_color_map[box] = STANDARD_COLORS[
                        classes[i] % len(STANDARD_COLORS)]

    if detection_boundary_mask is not None:
        detection_boundary_mask = np.logical_not(
            detection_boundary_mask).astype(np.uint8)

        draw_mask_on_image_array(image,
                                 detection_boundary_mask,
                                 color="Black",
                                 alpha=0.6)
    # Draw all boxes onto image.
    for box, color in box_to_color_map.items():
        ymin, xmin, ymax, xmax = box

        draw_bounding_box_on_image_array(
            image,
            ymin,
            xmin,
            ymax,
            xmax,
            color=color,
            thickness=line_thickness,
            display_str_list=box_to_display_str_map[box],
            use_normalized_coordinates=use_normalized_coordinates,
            detection_box_ignored=box_to_detection_box_ignored_map.get(box),
        )

    return image
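
One detail worth noting: the membership test classes[i] in six.viewkeys(category_index) is equivalent to classes[i] in category_index; the view form only matters for set algebra. A tiny sketch:

import six

category_index = {1: {'name': 'person'}}
print(1 in six.viewkeys(category_index), 1 in category_index)  # True True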
Code Example #35
def get_boxes(image,
              boxes,
              classes,
              scores,
              category_index,
              instance_masks=None,
              instance_boundaries=None,
              keypoints=None,
              track_ids=None,
              use_normalized_coordinates=False,
              max_boxes_to_draw=20,
              min_score_thresh=.5,
              agnostic_mode=False,
              line_thickness=4,
              groundtruth_box_visualization_color='black',
              skip_scores=False,
              skip_labels=False,
              skip_track_ids=False):
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_masks_map = {}
    box_to_instance_boundaries_map = {}
    box_to_keypoints_map = collections.defaultdict(list)
    box_to_track_ids_map = {}
    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]
    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        if scores is None or scores[i] > min_score_thresh:
            box = tuple(boxes[i].tolist())
            if instance_masks is not None:
                box_to_instance_masks_map[box] = instance_masks[i]
            if instance_boundaries is not None:
                box_to_instance_boundaries_map[box] = instance_boundaries[i]
            if keypoints is not None:
                box_to_keypoints_map[box].extend(keypoints[i])
            if track_ids is not None:
                box_to_track_ids_map[box] = track_ids[i]
            if scores is None:
                box_to_color_map[box] = groundtruth_box_visualization_color
            else:
                display_str = ''
                if not skip_labels:
                    if not agnostic_mode:
                        if classes[i] in six.viewkeys(category_index):
                            class_name = category_index[classes[i]]['name']
                        else:
                            class_name = 'N/A'
                        display_str = str(class_name)
                if not skip_scores:
                    if not display_str:
                        display_str = '{}%'.format(int(100 * scores[i]))
                    else:
                        display_str = '{}: {}%'.format(display_str,
                                                       int(100 * scores[i]))
                if not skip_track_ids and track_ids is not None:
                    if not display_str:
                        display_str = 'ID {}'.format(track_ids[i])
                    else:
                        display_str = '{}: ID {}'.format(
                            display_str, track_ids[i])
                box_to_display_str_map[box].append(display_str)
                if agnostic_mode:
                    box_to_color_map[box] = 'DarkOrange'
                elif track_ids is not None:
                    prime_multipler = vis_util._get_multiplier_for_color_randomness(
                    )
                    box_to_color_map[box] = vis_util.STANDARD_COLORS[
                        (prime_multipler * track_ids[i]) %
                        len(vis_util.STANDARD_COLORS)]
                else:
                    box_to_color_map[box] = vis_util.STANDARD_COLORS[
                        classes[i] % len(vis_util.STANDARD_COLORS)]

    return box_to_display_str_map
Code Example #36
File: system_test.py Project: wphicks/girder
    def testSettings(self):
        users = self.users

        # Only admins should be able to get or set settings
        for method in ('GET', 'PUT', 'DELETE'):
            resp = self.request(path='/system/setting',
                                method=method,
                                params={
                                    'key': 'foo',
                                    'value': 'bar'
                                },
                                user=users[1])
            self.assertStatus(resp, 403)

        # Only valid setting keys should be allowed
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'key': 'foo',
                                'value': 'bar'
                            },
                            user=users[0])
        self.assertStatus(resp, 400)
        self.assertEqual(resp.json['field'], 'key')

        # Only a valid JSON list is permitted
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={'list': json.dumps('not_a_list')},
                            user=users[0])
        self.assertStatus(resp, 400)

        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={'list': json.dumps('not_a_list')},
                            user=users[0])
        self.assertStatus(resp, 400)

        # Set an invalid setting value, should fail
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'key': SettingKey.BANNER_COLOR,
                                'value': 'bar'
                            },
                            user=users[0])
        self.assertStatus(resp, 400)
        self.assertEqual(resp.json['message'],
                         'The banner color must be a hex color triplet')

        # Set a valid value
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'key': SettingKey.BANNER_COLOR,
                                'value': '#121212'
                            },
                            user=users[0])
        self.assertStatusOk(resp)

        # We should now be able to retrieve it
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={'key': SettingKey.BANNER_COLOR},
                            user=users[0])
        self.assertStatusOk(resp)
        self.assertEqual(resp.json, '#121212')

        # We should now clear the setting
        resp = self.request(path='/system/setting',
                            method='DELETE',
                            params={'key': SettingKey.BANNER_COLOR},
                            user=users[0])
        self.assertStatusOk(resp)

        # Setting should now be default
        setting = Setting().get(SettingKey.BANNER_COLOR)
        self.assertEqual(setting,
                         SettingDefault.defaults[SettingKey.BANNER_COLOR])

        # We should also be able to set several settings using a JSON list
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'list':
                                json.dumps([
                                    {
                                        'key': SettingKey.BANNER_COLOR,
                                        'value': '#121212'
                                    },
                                    {
                                        'key': SettingKey.COOKIE_LIFETIME,
                                        'value': None
                                    },
                                ])
                            },
                            user=users[0])
        self.assertStatusOk(resp)

        # We can get a list as well
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={
                                'list':
                                json.dumps([
                                    SettingKey.BANNER_COLOR,
                                    SettingKey.COOKIE_LIFETIME,
                                ])
                            },
                            user=users[0])
        self.assertStatusOk(resp)
        self.assertEqual(resp.json[SettingKey.BANNER_COLOR], '#121212')

        # Try to set each key in turn to test the validation.  First test with
        # an invalid value, then test with the default value.  If the value
        # 'bad' won't trigger a validation error, the key should be present in
        # the badValues table.
        badValues = {
            SettingKey.BRAND_NAME: '',
            SettingKey.BANNER_COLOR: '',
            SettingKey.EMAIL_FROM_ADDRESS: '',
            SettingKey.PRIVACY_NOTICE: '',
            SettingKey.CONTACT_EMAIL_ADDRESS: '',
            SettingKey.EMAIL_HOST: {},
            SettingKey.SMTP_HOST: '',
            SettingKey.SMTP_PASSWORD: {},
            SettingKey.SMTP_USERNAME: {},
            SettingKey.CORS_ALLOW_ORIGIN: {},
            SettingKey.CORS_ALLOW_METHODS: {},
            SettingKey.CORS_ALLOW_HEADERS: {},
            SettingKey.CORS_EXPOSE_HEADERS: {},
        }
        allKeys = dict.fromkeys(six.viewkeys(SettingDefault.defaults))
        allKeys.update(badValues)
        for key in allKeys:
            resp = self.request(path='/system/setting',
                                method='PUT',
                                params={
                                    'key': key,
                                    'value': badValues.get(key, 'bad')
                                },
                                user=users[0])
            self.assertStatus(resp, 400)
            resp = self.request(path='/system/setting',
                                method='PUT',
                                params={
                                    'key':
                                    key,
                                    'value':
                                    json.dumps(
                                        SettingDefault.defaults.get(key, ''))
                                },
                                user=users[0])
            self.assertStatusOk(resp)
            resp = self.request(
                path='/system/setting',
                method='PUT',
                params={'list': json.dumps([{
                    'key': key,
                    'value': None
                }])},
                user=users[0])
            self.assertStatusOk(resp)
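The key-enumeration idiom at the end of this test is worth isolating. A minimal sketch, with hypothetical setting keys standing in for Girder's SettingDefault table:

import six

defaults = {'core.brand_name': 'Girder', 'core.banner_color': '#3F3B3B'}
bad_values = {'core.banner_color': ''}  # keys where generic 'bad' would pass

# Start from every known key (each mapped to None), then overlay the
# explicit bad values; anything not listed falls back to generic 'bad'.
all_keys = dict.fromkeys(six.viewkeys(defaults))
all_keys.update(bad_values)
for key in all_keys:
    value = bad_values.get(key, 'bad')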
Code example #37
    def testRegisterAndLoginBcrypt(self):
        """
        Test user registration and logging in.
        """
        cherrypy.config['auth']['hash_alg'] = 'bcrypt'
        # Set this to minimum so test runs faster.
        cherrypy.config['auth']['bcrypt_rounds'] = 4

        params = {
            'email': 'bad_email',
            'login': '******',
            'firstName': 'First',
            'lastName': 'Last',
            'password': '******'
        }
        # First test all of the required parameters.
        self.ensureRequiredParams(path='/user',
                                  method='POST',
                                  required=six.viewkeys(params))

        # Now test parameter validation
        resp = self.request(path='/user', method='POST', params=params)
        self.assertValidationError(resp, 'password')
        self.assertEqual(cherrypy.config['users']['password_description'],
                         resp.json['message'])

        params['password'] = '******'
        resp = self.request(path='/user', method='POST', params=params)
        self.assertValidationError(resp, 'login')

        # Make login something that violates the regex but doesn't contain @
        params['login'] = '******'
        resp = self.request(path='/user', method='POST', params=params)
        self.assertValidationError(resp, 'login')
        self.assertEqual(cherrypy.config['users']['login_description'],
                         resp.json['message'])

        params['login'] = '******'
        resp = self.request(path='/user', method='POST', params=params)
        self.assertValidationError(resp, 'email')

        # Now successfully create the user
        params['email'] = '*****@*****.**'
        resp = self.request(path='/user', method='POST', params=params)
        self.assertStatusOk(resp)
        self._verifyUserDocument(resp.json)

        user = User().load(resp.json['_id'], force=True)
        self.assertEqual(user['hashAlg'], 'bcrypt')

        # Try logging in without basic auth, should get 401
        resp = self.request(path='/user/authentication', method='GET')
        self.assertStatus(resp, 401)

        # Bad authentication header
        resp = self.request(path='/user/authentication',
                            method='GET',
                            additionalHeaders=[('Girder-Authorization',
                                                'Basic Not-Valid-64')])
        self.assertStatus(resp, 401)
        self.assertEqual('Invalid HTTP Authorization header',
                         resp.json['message'])
        resp = self.request(path='/user/authentication',
                            method='GET',
                            additionalHeaders=[('Girder-Authorization',
                                                'Basic NotValid')])
        self.assertStatus(resp, 401)
        self.assertEqual('Invalid HTTP Authorization header',
                         resp.json['message'])

        # Login with unregistered email
        resp = self.request(path='/user/authentication',
                            method='GET',
                            basicAuth='[email protected]:badpassword')
        self.assertStatus(resp, 401)
        self.assertEqual('Login failed.', resp.json['message'])

        # Correct email, but wrong password
        resp = self.request(path='/user/authentication',
                            method='GET',
                            basicAuth='[email protected]:badpassword')
        self.assertStatus(resp, 401)
        self.assertEqual('Login failed.', resp.json['message'])

        # Login successfully with email
        resp = self.request(path='/user/authentication',
                            method='GET',
                            basicAuth='[email protected]:good:password')
        self.assertStatusOk(resp)
        self.assertHasKeys(resp.json, ['authToken'])
        self.assertHasKeys(resp.json['authToken'], ['token', 'expires'])
        self._verifyAuthCookie(resp)

        # Invalid login
        resp = self.request(path='/user/authentication',
                            method='GET',
                            basicAuth='badlogin:good:password')
        self.assertStatus(resp, 401)
        self.assertEqual('Login failed.', resp.json['message'])

        # Login successfully with fallback Authorization header
        resp = self.request(path='/user/authentication',
                            method='GET',
                            basicAuth='goodlogin:good:password',
                            authHeader='Authorization')
        self.assertStatusOk(resp)

        # Test secure cookie validation
        with self.assertRaises(ValidationException):
            Setting().set(SettingKey.SECURE_COOKIE, 'bad value')
        # Set secure cookie value
        Setting().set(SettingKey.SECURE_COOKIE, True)

        # Login successfully with login
        resp = self.request(path='/user/authentication',
                            method='GET',
                            basicAuth='goodlogin:good:password')
        self.assertStatusOk(resp)

        # Make sure we got a nice (secure) cookie
        self._verifyAuthCookie(resp, secure=True)

        # Test user/me
        resp = self.request(path='/user/me', method='GET', user=user)
        self.assertStatusOk(resp)
        self.assertEqual(resp.json['login'], user['login'])
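The bcrypt_rounds=4 setting above exists purely to speed up the test: bcrypt's cost parameter scales the work exponentially (2**rounds iterations). A minimal sketch assuming the standalone `bcrypt` package, which is not necessarily the library Girder wraps internally:

import bcrypt  # assumes the `bcrypt` package is installed

pw = b'good:password'
digest = bcrypt.hashpw(pw, bcrypt.gensalt(rounds=4))  # 4 is the minimum cost
assert bcrypt.checkpw(pw, digest)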
Code example #38
    def handle_param_renaming_with_kwargs(self,
                                          new_name,
                                          old_names,
                                          context,
                                          new_name_value,
                                          new_name_is_provided,
                                          user_kwargs,
                                          transform_old=lambda x: x):
        """Look for errors with a particular parameter in an invocation

        Example
        -------
        >>> def fn(newname='default', **kwargs):
        ...     newname, kwargs = deprecation_pool.handle_param_renaming_with_kwargs(
        ...         new_name='newname', old_names={'oldname': '0.2.3'},
        ...         new_name_value=newname, context='the fn function',
        ...         new_name_is_provided=newname != 'default',
        ...         user_kwargs=kwargs,
        ...     )
        ...     return newname

        >>> fn() # Nothing happens
        'default'

        >>> fn(newname='default') # Nothing happens
        'default'

        >>> fn(oldname='aha') # A warning is issued the first time only
        WARNING:root:`oldname` parameter in `the fn function` is deprecated since v0.2.3, use `newname` instead
        'aha'

        >>> fn(newname='default', oldname='default') # No warning after the first use
        'default'

        >>> fn(newname='default', oldname='aha')
        'aha'

        >>> fn(newname='aha', oldname='default') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        NameError: Using both `newname` and `oldname`, `oldname` is deprecated

        >>> fn(newname='aha', oldname='aha') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        NameError: Using both `newname` and `oldname`, `oldname` is deprecated

        """
        deprecated_names_used = six.viewkeys(old_names) & six.viewkeys(
            user_kwargs)
        if len(deprecated_names_used) == 0:
            return new_name_value, user_kwargs
        n = deprecated_names_used.pop()
        if new_name_is_provided:
            raise NameError(
                'Using both `{}` and `{}` in `{}`, `{}` is deprecated'.format(
                    new_name,
                    n,
                    context,
                    n,
                ))

        key = (context, new_name, n)
        if key not in self._seen:
            self._seen.add(key)
            logging.warning(
                '`{}` parameter in `{}` is deprecated since v{}, use `{}` instead'
                .format(
                    n,
                    context,
                    old_names[n],
                    new_name,
                ))
        v = transform_old(user_kwargs[n])
        del user_kwargs[n]
        return v, user_kwargs
Code example #39
File: report.py Project: juewe/pandas-profiling
def to_html(sample, stats_object):
    """Generate a HTML report from summary statistics and a given sample.

    Parameters
    ----------
    sample : DataFrame
        the sample you want to print
    stats_object : dict
        Summary statistics. Should be generated with an appropriate describe() function

    Returns
    -------
    str
        containing profile report in HTML format

    Notes
    -----
        * This function has to be refactored since it is huge and contains inner functions
    """

    n_obs = stats_object['table']['n']

    value_formatters = formatters.value_formatters
    row_formatters = formatters.row_formatters

    if not isinstance(sample, pd.DataFrame):
        raise TypeError("sample must be of type pandas.DataFrame")

    if not isinstance(stats_object, dict):
        raise TypeError(
            "stats_object must be of type dict. Did you generate this using the pandas_profiling.describe() function?"
        )

    if not {'table', 'variables', 'freq', 'correlations'}.issubset(
            stats_object.keys()):
        raise TypeError(
            "stats_object badly formatted. Did you generate this using the pandas_profiling.describe() function?"
        )

    def fmt(value, name):
        if pd.isnull(value):
            return ""
        if name in value_formatters:
            return value_formatters[name](value)
        elif isinstance(value, float):
            return value_formatters[formatters.DEFAULT_FLOAT_FORMATTER](value)
        else:
            try:
                return unicode(value)  # Python 2
            except NameError:
                return str(value)  # Python 3

    def _format_row(freq, label, max_freq, row_template, n, extra_class=''):
        if max_freq != 0:
            width = int(freq / max_freq * 99) + 1
        else:
            width = 1

        if width > 20:
            label_in_bar = freq
            label_after_bar = ""
        else:
            label_in_bar = "&nbsp;"
            label_after_bar = freq

        return row_template.render(label=label,
                                   width=width,
                                   count=freq,
                                   percentage='{:2.1f}'.format(freq / n * 100),
                                   extra_class=extra_class,
                                   label_in_bar=label_in_bar,
                                   label_after_bar=label_after_bar)

    def freq_table(freqtable,
                   n,
                   table_template,
                   row_template,
                   max_number_to_print,
                   nb_col=6):

        freq_rows_html = u''

        if max_number_to_print > n:
            max_number_to_print = n

        if max_number_to_print < len(freqtable):
            freq_other = sum(freqtable.iloc[max_number_to_print:])
            min_freq = freqtable.values[max_number_to_print]
        else:
            freq_other = 0
            min_freq = 0

        freq_missing = n - sum(freqtable)
        max_freq = max(freqtable.values[0], freq_other, freq_missing)

        # TODO: Correctly sort missing and other

        for label, freq in six.iteritems(
                freqtable.iloc[0:max_number_to_print]):
            freq_rows_html += _format_row(freq, label, max_freq, row_template,
                                          n)

        if freq_other > min_freq:
            freq_rows_html += _format_row(
                freq_other,
                "Other values (%s)" %
                (freqtable.count() - max_number_to_print),
                max_freq,
                row_template,
                n,
                extra_class='other')

        if freq_missing > min_freq:
            freq_rows_html += _format_row(freq_missing,
                                          "(Missing)",
                                          max_freq,
                                          row_template,
                                          n,
                                          extra_class='missing')

        return table_template.render(rows=freq_rows_html,
                                     varid=hash(idx),
                                     nb_col=nb_col)

    def extreme_obs_table(freqtable,
                          table_template,
                          row_template,
                          number_to_print,
                          n,
                          ascending=True):

        # If it's mixed between base types (str, int) convert to str. Pure "mixed" types are filtered during type discovery
        if "mixed" in freqtable.index.inferred_type:
            freqtable.index = freqtable.index.astype(str)

        sorted_freqTable = freqtable.sort_index()

        if ascending:
            obs_to_print = sorted_freqTable.iloc[:number_to_print]
        else:
            obs_to_print = sorted_freqTable.iloc[-number_to_print:]

        freq_rows_html = ''
        max_freq = max(obs_to_print.values)

        for label, freq in six.iteritems(obs_to_print):
            freq_rows_html += _format_row(freq, label, max_freq, row_template,
                                          n)

        return table_template.render(rows=freq_rows_html)

    # Variables
    rows_html = u""
    messages = []

    for idx, row in stats_object['variables'].iterrows():

        formatted_values = {'varname': idx, 'varid': hash(idx)}
        row_classes = {}

        for col, value in six.iteritems(row):
            formatted_values[col] = fmt(value, col)

        for col in set(row.index) & six.viewkeys(row_formatters):
            row_classes[col] = row_formatters[col](row[col])
            if row_classes[col] == "alert" and col in templates.messages:
                messages.append(templates.messages[col].format(
                    formatted_values, varname=idx))

        if row['type'] in {'CAT', 'BOOL'}:
            formatted_values['minifreqtable'] = freq_table(
                stats_object['freq'][idx], n_obs,
                templates.template('mini_freq_table'),
                templates.template('mini_freq_table_row'), 3,
                templates.mini_freq_table_nb_col[row['type']])

            if row['distinct_count'] > 50:
                messages.append(templates.messages['HIGH_CARDINALITY'].format(
                    formatted_values, varname=idx))
                row_classes['distinct_count'] = "alert"
            else:
                row_classes['distinct_count'] = ""

        if row['type'] == 'UNIQUE':
            obs = stats_object['freq'][idx].index

            formatted_values['firstn'] = pd.DataFrame(
                obs[0:3],
                columns=["First 3 values"]).to_html(classes="example_values",
                                                    index=False)
            formatted_values['lastn'] = pd.DataFrame(
                obs[-3:],
                columns=["Last 3 values"]).to_html(classes="example_values",
                                                   index=False)
        if row['type'] == 'UNSUPPORTED':
            formatted_values['varname'] = idx
            messages.append(
                templates.messages[row['type']].format(formatted_values))
        elif row['type'] in {'CORR', 'CONST', 'RECODED'}:
            formatted_values['varname'] = idx
            messages.append(
                templates.messages[row['type']].format(formatted_values))
        else:
            formatted_values['freqtable'] = freq_table(
                stats_object['freq'][idx], n_obs,
                templates.template('freq_table'),
                templates.template('freq_table_row'), 10)
            formatted_values['firstn_expanded'] = extreme_obs_table(
                stats_object['freq'][idx],
                templates.template('freq_table'),
                templates.template('freq_table_row'),
                5,
                n_obs,
                ascending=True)
            formatted_values['lastn_expanded'] = extreme_obs_table(
                stats_object['freq'][idx],
                templates.template('freq_table'),
                templates.template('freq_table_row'),
                5,
                n_obs,
                ascending=False)

        rows_html += templates.row_templates_dict[row['type']].render(
            values=formatted_values, row_classes=row_classes)

    # Overview
    formatted_values = {
        k: fmt(v, k)
        for k, v in six.iteritems(stats_object['table'])
    }

    row_classes = {}
    for col in six.viewkeys(
            stats_object['table']) & six.viewkeys(row_formatters):
        row_classes[col] = row_formatters[col](stats_object['table'][col])
        if row_classes[col] == "alert" and col in templates.messages:
            messages.append(templates.messages[col].format(formatted_values,
                                                           varname=idx))

    messages_html = u''
    for msg in messages:
        messages_html += templates.message_row.format(message=msg)

    overview_html = templates.template('overview').render(
        values=formatted_values,
        row_classes=row_classes,
        messages=messages_html)

    # Add plot of matrix correlation
    pearson_matrix = plot.correlation_matrix(
        stats_object['correlations']['pearson'], 'Pearson')
    spearman_matrix = plot.correlation_matrix(
        stats_object['correlations']['spearman'], 'Spearman')
    correlations_html = templates.template('correlations').render(
        values={
            'pearson_matrix': pearson_matrix,
            'spearman_matrix': spearman_matrix
        })

    # Add sample
    sample_html = templates.template('sample').render(
        sample_table_html=sample.to_html(classes="sample"))
    # TODO: should be done in the template
    return templates.template('base').render({
        'overview_html': overview_html,
        'rows_html': rows_html,
        'sample_html': sample_html,
        'correlation_html': correlations_html,
        'dataframe_name': dataframe_name,  # assumed module-level globals in
        'statement': statement,            # this fork; not defined locally
    })
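A minimal end-to-end sketch of how this report function is meant to be driven, assuming the legacy pandas_profiling 1.x API whose describe() returns the stats dictionary expected here:

import pandas as pd
import pandas_profiling  # legacy 1.x API assumed

df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': ['x', 'y', 'x', 'y']})
stats = pandas_profiling.describe(df)   # builds the stats_object dict
html = to_html(df.head(), stats)        # render the profile report
with open('report.html', 'w') as handle:
    handle.write(html)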
Code example #40
def to_html(sample, stats_object):
    """
    Generate an HTML report from summary statistics and a given sample
    Parameters
    ----------
    sample: DataFrame containing the sample you want to print
    stats_object: Dictionary containing summary statistics. Should be generated with an appropriate describe() function

    Returns
    -------
    str, containing profile report in HTML format
    """

    n_obs = stats_object['table']['n']

    value_formatters = formatters.value_formatters
    row_formatters = formatters.row_formatters

    if not isinstance(sample, pd.DataFrame):
        raise TypeError("sample must be of type pandas.DataFrame")

    if not isinstance(stats_object, dict):
        raise TypeError(
            "stats_object must be of type dict. Did you generate this using the spark_df_profiling.describe() function?"
        )

    if set(stats_object.keys()) != {'table', 'variables', 'freq'}:
        raise TypeError(
            "stats_object badly formatted. Did you generate this using the spark_df_profiling-eda.describe() function?"
        )

    def fmt(value, name):
        if pd.isnull(value):
            return ""
        if name in value_formatters:
            return value_formatters[name](value)
        elif isinstance(value, float):
            return value_formatters[formatters.DEFAULT_FLOAT_FORMATTER](value)
        else:
            if sys.version_info.major == 3:
                return str(value)
            else:
                return unicode(value)

    def freq_table(freqtable, n, var_table, table_template, row_template,
                   max_number_of_items_in_table):

        local_var_table = var_table.copy()
        freq_other_prefiltered = freqtable["***Other Values***"]
        freq_other_prefiltered_num = freqtable[
            "***Other Values Distinct Count***"]
        freqtable = freqtable.drop(
            ["***Other Values***", "***Other Values Distinct Count***"])

        freq_rows_html = u''

        freq_other = sum(
            freqtable[max_number_of_items_in_table:]) + freq_other_prefiltered
        freq_missing = var_table["n_missing"]
        max_freq = max(freqtable.values[0], freq_other, freq_missing)
        try:
            min_freq = freqtable.values[max_number_of_items_in_table]
        except IndexError:
            min_freq = 0

        # TODO: Correctly sort missing and other

        def format_row(freq, label, extra_class=''):
            width = int(freq / float(max_freq) * 99) + 1
            if width > 20:
                label_in_bar = freq
                label_after_bar = ""
            else:
                label_in_bar = "&nbsp;"
                label_after_bar = freq

            return row_template.render(label=label,
                                       width=width,
                                       count=freq,
                                       percentage='{:2.1f}'.format(
                                           freq / float(n) * 100),
                                       extra_class=extra_class,
                                       label_in_bar=label_in_bar,
                                       label_after_bar=label_after_bar)

        for label, freq in six.iteritems(
                freqtable[0:max_number_of_items_in_table]):
            freq_rows_html += format_row(freq, label)

        if freq_other > min_freq:
            freq_rows_html += format_row(
                freq_other,
                "Other values (%s)" %
                (freqtable.count() + freq_other_prefiltered_num -
                 max_number_of_items_in_table),
                extra_class='other')

        if freq_missing > min_freq:
            freq_rows_html += format_row(freq_missing,
                                         "(Missing)",
                                         extra_class='missing')

        return table_template.render(rows=freq_rows_html, varid=hash(idx))

    # Variables
    rows_html = u""
    messages = []

    for idx, row in stats_object['variables'].iterrows():

        formatted_values = {'varname': idx, 'varid': hash(idx)}
        row_classes = {}

        for col, value in six.iteritems(row):
            formatted_values[col] = fmt(value, col)

        for col in set(row.index) & six.viewkeys(row_formatters):
            row_classes[col] = row_formatters[col](row[col])
            if row_classes[col] == "alert" and col in templates.messages:
                messages.append(templates.messages[col].format(
                    formatted_values, varname=formatters.fmt_varname(idx)))

        if row['type'] == 'CAT':
            formatted_values['minifreqtable'] = freq_table(
                stats_object['freq'][idx], n_obs,
                stats_object['variables'].loc[idx],
                templates.template('mini_freq_table'),
                templates.template('mini_freq_table_row'), 3)
            formatted_values['freqtable'] = freq_table(
                stats_object['freq'][idx], n_obs,
                stats_object['variables'].loc[idx],
                templates.template('freq_table'),
                templates.template('freq_table_row'), 20)
            if row['distinct_count'] > 50:
                messages.append(templates.messages['HIGH_CARDINALITY'].format(
                    formatted_values, varname=formatters.fmt_varname(idx)))
                row_classes['distinct_count'] = "alert"
            else:
                row_classes['distinct_count'] = ""

        if row['type'] == 'UNIQUE':
            obs = stats_object['freq'][idx].index

            formatted_values['firstn'] = pd.DataFrame(
                obs[0:3],
                columns=["First 3 values"]).to_html(classes="example_values",
                                                    index=False)
            formatted_values['lastn'] = pd.DataFrame(
                obs[-3:],
                columns=["Last 3 values"]).to_html(classes="example_values",
                                                   index=False)

            if n_obs > 40:
                formatted_values['firstn_expanded'] = pd.DataFrame(
                    obs[0:20], index=range(1, 21)).to_html(
                        classes="sample table table-hover", header=False)
                formatted_values['lastn_expanded'] = pd.DataFrame(
                    obs[-20:], index=range(n_obs - 20 + 1, n_obs + 1)).to_html(
                        classes="sample table table-hover", header=False)
            else:
                formatted_values['firstn_expanded'] = pd.DataFrame(
                    obs, index=range(1, n_obs + 1)).to_html(
                        classes="sample table table-hover", header=False)
                formatted_values['lastn_expanded'] = ''

        rows_html += templates.row_templates_dict[row['type']].render(
            values=formatted_values, row_classes=row_classes)

        if row['type'] in {'CORR', 'CONST'}:
            formatted_values['varname'] = formatters.fmt_varname(idx)
            messages.append(
                templates.messages[row['type']].format(formatted_values))

    # Overview
    formatted_values = {
        k: fmt(v, k)
        for k, v in six.iteritems(stats_object['table'])
    }

    row_classes = {}
    for col in six.viewkeys(
            stats_object['table']) & six.viewkeys(row_formatters):
        row_classes[col] = row_formatters[col](stats_object['table'][col])
        if row_classes[col] == "alert" and col in templates.messages:
            messages.append(templates.messages[col].format(
                formatted_values, varname=formatters.fmt_varname(idx)))

    messages_html = u''
    for msg in messages:
        messages_html += templates.message_row.format(message=msg)

    overview_html = templates.template('overview').render(
        values=formatted_values,
        row_classes=row_classes,
        messages=messages_html)

    # Sample

    sample_html = templates.template('sample').render(
        sample_table_html=sample.to_html(classes="sample"))
    # TODO: should be done in the template
    return templates.template('base').render({
        'overview_html': overview_html,
        'rows_html': rows_html,
        'sample_html': sample_html
    })
Code example #41
def visualize_boxes_and_labels_on_image_array(
        # Modified section 1-2 (begin)
        item,
        w,
        h,
        # Modified section 1-2 (end)
        image,
        boxes,
        classes,
        scores,
        category_index,
        instance_masks=None,
        instance_boundaries=None,
        keypoints=None,
        track_ids=None,
        use_normalized_coordinates=False,
        max_boxes_to_draw=20,
        min_score_thresh=.5,
        agnostic_mode=False,
        line_thickness=4,
        groundtruth_box_visualization_color='black',
        skip_scores=False,
        skip_labels=False,
        skip_track_ids=False):
    """Overlay labeled boxes on an image with formatted scores and label names.

  This function groups boxes that correspond to the same location
  and creates a display string for each detection and overlays these
  on the image. Note that this function modifies the image in place, and returns
  that same image.

  Args:
    item: output filename used when saving the cropped face images below.
    w: width of the source image in pixels.
    h: height of the source image in pixels.
    image: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None.  If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plot all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    instance_masks: a numpy array of shape [N, image_height, image_width] with
      values ranging between 0 and 1, can be None.
    instance_boundaries: a numpy array of shape [N, image_height, image_width]
      with values ranging between 0 and 1, can be None.
    keypoints: a numpy array of shape [N, num_keypoints, 2], can
      be None
    track_ids: a numpy array of shape [N] with unique track ids. If provided,
      color-coding of boxes will be determined by these ids, and not the class
      indices.
    use_normalized_coordinates: whether boxes is to be interpreted as
      normalized coordinates or not.
    max_boxes_to_draw: maximum number of boxes to visualize.  If None, draw
      all boxes.
    min_score_thresh: minimum score threshold for a box to be visualized
    agnostic_mode: boolean (default: False) controlling whether to evaluate in
      class-agnostic mode or not.  This mode will display scores but ignore
      classes.
    line_thickness: integer (default: 4) controlling line width of the boxes.
    groundtruth_box_visualization_color: box color for visualizing groundtruth
      boxes
    skip_scores: whether to skip score when drawing a single detection
    skip_labels: whether to skip label when drawing a single detection
    skip_track_ids: whether to skip track id when drawing a single detection

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
  """
    # Create a display string (and color) for every box location, group any boxes
    # that correspond to the same location.
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_masks_map = {}
    box_to_instance_boundaries_map = {}
    box_to_keypoints_map = collections.defaultdict(list)
    box_to_track_ids_map = {}
    # print("실행 후 w = {} , h = {}".format(w, h))
    facecut_w9 = int(round(w / 9))
    facecut_w4 = int(round(w / 4))
    facecut_w10 = int(round(w / 10))
    facecut_w3 = int(round(w / 3))

    label_str = ""
    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]
    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        # min_score_thresh = 0.1
        if scores is None or scores[i] > min_score_thresh:
            box = tuple(boxes[i].tolist())
            if instance_masks is not None:
                box_to_instance_masks_map[box] = instance_masks[i]
            if instance_boundaries is not None:
                box_to_instance_boundaries_map[box] = instance_boundaries[i]
            if keypoints is not None:
                box_to_keypoints_map[box].extend(keypoints[i])
            if track_ids is not None:
                box_to_track_ids_map[box] = track_ids[i]
            if scores is None:
                box_to_color_map[box] = groundtruth_box_visualization_color
            else:
                display_str = ''
                if not skip_labels:
                    if not agnostic_mode:
                        if classes[i] in six.viewkeys(category_index):
                            class_name = category_index[classes[i]]['name']
                        # Case where the class comes out as 'N/A'
                        else:
                            class_name = 'N/A'
                        display_str = str(class_name)
                if not skip_scores:
                    if not display_str:
                        display_str = '{}%'.format(int(100 * scores[i]))
                    else:
                        display_str = '{}: {}%'.format(display_str,
                                                       int(100 * scores[i]))
                        # Print the resulting percentage
                        #print(display_str+", ")
                        label_str += '{}, {}%'.format(
                            class_name, int(100 * scores[i])) + ", "
                if not skip_track_ids and track_ids is not None:
                    if not display_str:
                        display_str = 'ID {}'.format(track_ids[i])
                    else:
                        display_str = '{}: ID {}'.format(
                            display_str, track_ids[i])
                box_to_display_str_map[box].append(display_str)
                if agnostic_mode:
                    box_to_color_map[box] = STANDARD_COLORS[i]
                    print(str(i), str(STANDARD_COLORS[i]))
                elif track_ids is not None:
                    prime_multiplier = _get_multiplier_for_color_randomness()
                    box_to_color_map[box] = STANDARD_COLORS[
                        (prime_multiplier * track_ids[i]) %
                        len(STANDARD_COLORS)]
                else:
                    box_to_color_map[box] = STANDARD_COLORS[
                        classes[i] % len(STANDARD_COLORS)]
    # Draw all boxes onto image.
    for box, color in box_to_color_map.items():
        ymin, xmin, ymax, xmax = box

        if instance_masks is not None:
            draw_mask_on_image_array(image,
                                     box_to_instance_masks_map[box],
                                     color=color)
        if instance_boundaries is not None:
            draw_mask_on_image_array(image,
                                     box_to_instance_boundaries_map[box],
                                     color='red',
                                     alpha=1.0)
        draw_bounding_box_on_image_array(
            image,
            ymin,
            xmin,
            ymax,
            xmax,
            color=color,
            thickness=line_thickness,
            display_str_list=box_to_display_str_map[box],
            use_normalized_coordinates=use_normalized_coordinates)

        # Modified section 1-3 (begin)
        if int(facecut_w9 * 2) <= int(xmin * w) and int(xmax * w) <= int(
                facecut_w9 * 7):
            image1 = image[0:h, facecut_w9 * 2:facecut_w9 * 7]
            output_path = "D:\\DFDC\\nose\\test_output\\nose\\v0.3\\total_real_test\\side_output1\\"
            cv2.imwrite(output_path + item, image1)

        # If the face is frontal, this if-statement passes
        if int(facecut_w4 * 1) <= int(xmin * w) and int(xmax * w) <= int(
                facecut_w4 * 3):
            image2 = image[0:h, facecut_w4 * 1:facecut_w4 * 3]
            output_path = "D:\\DFDC\\nose\\test_output\\nose\\v0.3\\total_real_test\\side_output2\\"
            cv2.imwrite(output_path + item, image2)

        if int(facecut_w10 * 3) <= int(xmin * w) and int(xmax * w) <= int(
                facecut_w10 * 7):
            image3 = image[0:h, facecut_w10 * 3:facecut_w10 * 7]
            output_path = "D:\\DFDC\\nose\\test_output\\nose\\v0.3\\total_real_test\\side_output3\\"
            cv2.imwrite(output_path + item, image3)

        if int(facecut_w3 * 1) <= int(xmin * w) and int(xmax * w) <= int(
                facecut_w3 * 2):
            image4 = image[0:h, facecut_w3 * 1:facecut_w3 * 2]
            output_path = "D:\\DFDC\\nose\\test_output\\nose\\v0.3\\total_real_test\\side_output4\\"
            cv2.imwrite(output_path + item, image4)

        # Modified section 1-3 (end)
        if keypoints is not None:
            draw_keypoints_on_image_array(
                image,
                box_to_keypoints_map[box],
                color=color,
                radius=line_thickness / 2,
                use_normalized_coordinates=use_normalized_coordinates)

    return image, label_str
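The color-selection branches above all reduce to deterministic modular indexing into STANDARD_COLORS. A compact sketch with a truncated stand-in color list (the real list lives in the TensorFlow Object Detection API's visualization utilities):

# Stand-in for the much longer STANDARD_COLORS list.
STANDARD_COLORS = ['AliceBlue', 'Chartreuse', 'Aqua', 'DarkOrange']

def color_for(class_or_track_id, multiplier=1):
    # The same id always maps to the same color; a prime multiplier is
    # used above to spread consecutive track ids across the palette.
    return STANDARD_COLORS[(multiplier * class_or_track_id) %
                           len(STANDARD_COLORS)]

assert color_for(2) == color_for(2 + len(STANDARD_COLORS))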
Code example #42
 def keys(self):
     return [key_validator(k) for k in six.viewkeys(self._json_data)]
Code example #43
 def keys(self):
     return six.viewkeys(self._json_data)
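These two tiny keys() wrappers rely on the central property of six.viewkeys: the returned object is a set-like view (dict.viewkeys() on Python 2, dict.keys() on Python 3), so set algebra works without copying. A quick illustration with made-up dicts:

import six

defaults = {'color': 'red', 'size': 10}
overrides = {'size': 12, 'alpha': 0.5}

shared = six.viewkeys(defaults) & six.viewkeys(overrides)
extra = six.viewkeys(overrides) - six.viewkeys(defaults)
assert shared == {'size'} and extra == {'alpha'}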
Code example #44
File: git.py Project: LatticeEngines/rbtools
    def get_commit_history(self, revisions):
        """Return the commit history specified by the revisions.

        Args:
            revisions (dict):
                A dictionary of revisions to generate history for, as returned
                by :py:meth:`parse_revision_spec`.

        Returns:
            list of dict:
            The list of history entries, in order. The dictionaries have the
            following keys:

            ``commit_id``:
                The unique identifier of the commit.

            ``parent_id``:
                The unique identifier of the parent commit.

            ``author_name``:
                The name of the commit's author.

            ``author_email``:
                The e-mail address of the commit's author.

            ``author_date``:
                The date the commit was authored.

            ``committer_name``:
                The committer's name.

            ``committer_email``:
                The e-mail address of the committer.

            ``committer_date``:
                The date the commit was committed.

            ``commit_message``:
                The commit's message.

        Raises:
            rbtools.clients.errors.SCMError:
                The history is non-linear or there is a commit with no parents.
        """
        log_fields = {
            'commit_id': b'%H',
            'parent_id': b'%P',
            'author_name': b'%an',
            'author_email': b'%ae',
            'author_date': b'%ad',
            'committer_name': b'%cn',
            'committer_email': b'%ce',
            'committer_date': b'%cd',
            'commit_message': b'%B',
        }

        # 0x1f is the ASCII field separator. It is a non-printable character
        # that should not appear in any field in `git log`.
        log_format = b'%x1f'.join(six.itervalues(log_fields))

        log_entries = execute([
            self.git,
            b'log',
            b'-z',
            b'--reverse',
            b'--pretty=format:%s' % log_format,
            b'--date=iso8601-strict',
            b'%s..%s' % (bytes(revisions['base']), bytes(revisions['tip'])),
        ],
                              ignore_errors=True,
                              none_on_ignored_error=True,
                              results_unicode=True)

        if not log_entries:
            return None

        history = []
        field_names = six.viewkeys(log_fields)

        for log_entry in log_entries.split(self._NUL):
            fields = log_entry.split(self._FIELD_SEP)
            entry = dict(zip(field_names, fields))

            parents = entry['parent_id'].split()

            if len(parents) > 1:
                raise SCMError(
                    'The Git SCMClient only supports posting commit histories '
                    'that are entirely linear.')
            elif len(parents) == 0:
                raise SCMError(
                    'The Git SCMClient only supports posting commits that '
                    'have exactly one parent.')

            history.append(entry)

        return history
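The parsing scheme above is easiest to see on a tiny fixture: records are NUL-separated (git log -z) and fields within a record are separated by the ASCII unit separator 0x1f. A self-contained sketch with fake commit data:

FIELD_SEP = '\x1f'   # ASCII unit separator; never appears in git log fields
NUL = '\x00'         # record separator produced by `git log -z`

field_names = ['commit_id', 'parent_id']  # subset, for illustration
raw = 'c0ffee\x1fdecade\x00beef42\x1fc0ffee'

entries = [dict(zip(field_names, record.split(FIELD_SEP)))
           for record in raw.split(NUL)]
assert entries[1]['parent_id'] == 'c0ffee'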
Code example #45
File: qplot.py Project: yejianye/plotnine
def qplot(x=None, y=None, data=None, facets=None, margins=False,
          geom='auto', xlim=None, ylim=None, log='', main=None,
          xlab=None, ylab=None, asp=None, **kwargs):
    """
    Quick plot

    Parameters
    ----------
    x: str | array_like
        x aesthetic
    y: str | array_like
        y aesthetic
    data: pandas.DataFrame
        Data frame to use (optional). If not specified,
        will create one, extracting arrays from the
        current environment.
    geom: str | list
        *geom(s)* to do the drawing. If ``auto``, defaults
        to 'point' if ``x`` and ``y`` are specified or
        'histogram' if only ``x`` is specified.
    xlim: tuple
        x-axis limits
    ylim: tuple
        y-axis limits
    log: 'x' | 'y' | 'xy'
        Which variables to log transform.
    main: str
        Plot title
    xlab: str
        x-axis label
    ylab: str
        y-axis label
    asp: str | float
        The y/x aspect ratio.
    kwargs: dict
        Arguments passed on to the geom.

    Returns
    -------
    p: ggplot
        ggplot object
    """
    # Extract all recognizable aesthetic mappings from the parameters
    # String values e.g  "I('red')", "I(4)" are not treated as mappings

    environment = EvalEnvironment.capture(1)
    aesthetics = {} if x is None else {'x': x}
    if y is not None:
        aesthetics['y'] = y

    def is_mapping(value):
        """
        Return True if value is not enclosed in I() function
        """
        with suppress(AttributeError):
            return not (value.startswith('I(') and value.endswith(')'))
        return True

    def I(value):
        return value

    I_env = EvalEnvironment([{'I': I}])

    for ae in six.viewkeys(kwargs) & all_aesthetics:
        value = kwargs[ae]
        if is_mapping(value):
            aesthetics[ae] = value
        else:
            kwargs[ae] = I_env.eval(value)

    # List of geoms
    if is_string(geom):
        geom = [geom]
    elif isinstance(geom, tuple):
        geom = list(geom)

    if data is None:
        data = pd.DataFrame()

    # Work out plot data, and modify aesthetics, if necessary
    def replace_auto(lst, str2):
        """
        Replace all occurrences of 'auto' in lst with str2
        """
        for i, value in enumerate(lst):
            if value == 'auto':
                lst[i] = str2
        return lst

    if 'auto' in geom:
        if 'sample' in aesthetics:
            replace_auto(geom, 'qq')
        elif y is None:
            # If x is discrete we choose geom_bar, and geom_histogram
            # otherwise. But we need to evaluate the mapping to find
            # out the dtype
            env = environment.with_outer_namespace(
                {'factor': pd.Categorical})

            if isinstance(aesthetics['x'], six.string_types):
                try:
                    x = env.eval(aesthetics['x'], inner_namespace=data)
                except Exception:
                    msg = "Could not evaluate aesthetic 'x={}'"
                    raise PlotnineError(msg.format(aesthetics['x']))
            elif not hasattr(aesthetics['x'], 'dtype'):
                x = np.asarray(aesthetics['x'])

            if x.dtype.kind in DISCRETE_KINDS:
                replace_auto(geom, 'bar')
            else:
                replace_auto(geom, 'histogram')

        else:
            if x is None:
                if pdtypes.is_list_like(aesthetics['y']):
                    aesthetics['x'] = range(len(aesthetics['y']))
                    xlab = 'range(len(y))'
                    ylab = 'y'
                else:
                    # We could solve the issue in layer.compute_aesthetics
                    # but it is not worth the extra complexity
                    raise PlotnineError(
                        "Cannot infer how long x should be.")
            replace_auto(geom, 'point')

    p = ggplot(aes(**aesthetics), data=data, environment=environment)

    def get_facet_type(facets):
        with suppress(PlotnineError):
            parse_grid_facets(facets)
            return 'grid'

        with suppress(PlotnineError):
            parse_wrap_facets(facets)
            return 'wrap'

        warn("Could not determine the type of faceting, "
             "therefore no faceting.")
        return 'null'

    if facets:
        facet_type = get_facet_type(facets)
        if facet_type == 'grid':
            p += facet_grid(facets, margins=margins)
        elif facet_type == 'wrap':
            p += facet_wrap(facets)
        else:
            p += facet_null()

    # Add geoms
    for g in geom:
        geom_name = 'geom_{}'.format(g)
        geom_klass = Registry[geom_name]
        stat_name = 'stat_{}'.format(geom_klass.DEFAULT_PARAMS['stat'])
        stat_klass = Registry[stat_name]
        # find params
        recognized = (six.viewkeys(kwargs) &
                      (six.viewkeys(geom_klass.DEFAULT_PARAMS) |
                       geom_klass.aesthetics() |
                       six.viewkeys(stat_klass.DEFAULT_PARAMS) |
                       stat_klass.aesthetics()))
        recognized = recognized - six.viewkeys(aesthetics)
        params = {ae: kwargs[ae] for ae in recognized}
        p += geom_klass(**params)

    if 'x' in log:
        p += scale_x_log10()

    if 'y' in log:
        p += scale_y_log10()

    if xlab:
        p += xlabel(xlab)

    if ylab:
        p += ylabel(ylab)

    if main:
        p += ggtitle(main)

    if asp:
        p += theme(aspect_ratio=asp)

    return p
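A short usage sketch of qplot; the column names here are hypothetical, and the geom inference follows the rules in the docstring (scatter when both x and y are given, histogram when only x is given):

import pandas as pd
from plotnine import qplot

df = pd.DataFrame({'height': [1.5, 1.7, 1.8], 'weight': [55, 70, 80]})

p1 = qplot('height', 'weight', data=df)   # geom='auto' -> point
p2 = qplot('height', data=df)             # geom='auto' -> histogram
p3 = qplot('height', 'weight', data=df, log='y', main='Height vs weight')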
Code example #46
    def update_deleted(self, job):
        """
        High-level method that replicates a single partition that doesn't
        belong on this node.

        :param job: a dict containing info about the partition to be replicated
        """
        def tpool_get_suffixes(path):
            return [
                suff for suff in os.listdir(path)
                if len(suff) == 3 and isdir(join(path, suff))
            ]

        self.replication_count += 1
        self.logger.increment('partition.delete.count.%s' % (job['device'], ))
        headers = dict(self.default_headers)
        headers['X-Backend-Storage-Policy-Index'] = int(job['policy'])
        failure_devs_info = set()
        begin = time.time()
        handoff_partition_deleted = False
        try:
            responses = []
            suffixes = tpool.execute(tpool_get_suffixes, job['path'])
            synced_remote_regions = {}
            delete_objs = None
            if suffixes:
                for node in job['nodes']:
                    self.stats['rsync'] += 1
                    kwargs = {}
                    if node['region'] in synced_remote_regions and \
                            self.conf.get('sync_method', 'rsync') == 'ssync':
                        kwargs['remote_check_objs'] = \
                            synced_remote_regions[node['region']]
                    # candidates is a dict(hash=>timestamp) of objects
                    # for deletion
                    success, candidates = self.sync(node, job, suffixes,
                                                    **kwargs)
                    if success:
                        with Timeout(self.http_timeout):
                            conn = http_connect(node['replication_ip'],
                                                node['replication_port'],
                                                node['device'],
                                                job['partition'],
                                                'REPLICATE',
                                                '/' + '-'.join(suffixes),
                                                headers=headers)
                            conn.getresponse().read()
                        if node['region'] != job['region']:
                            synced_remote_regions[node['region']] = viewkeys(
                                candidates)
                    else:
                        failure_devs_info.add(
                            (node['replication_ip'], node['device']))
                    responses.append(success)
                for cand_objs in synced_remote_regions.values():
                    if delete_objs is None:
                        delete_objs = cand_objs
                    else:
                        delete_objs = delete_objs & cand_objs

            if self.handoff_delete:
                # delete handoff if we have had handoff_delete successes
                delete_handoff = len([resp for resp in responses if resp]) >= \
                    self.handoff_delete
            else:
                # delete handoff if all syncs were successful
                delete_handoff = len(responses) == len(job['nodes']) and \
                    all(responses)
            if delete_handoff:
                self.stats['remove'] += 1
                if (self.conf.get('sync_method', 'rsync') == 'ssync'
                        and delete_objs is not None):
                    self.logger.info(_("Removing %s objects"),
                                     len(delete_objs))
                    _junk, error_paths = self.delete_handoff_objs(
                        job, delete_objs)
                    # If replication from this handoff device succeeded but
                    # the cleanup failed, mark the remote devices targeted by
                    # that replication: a failed cleanup means the replicator
                    # needs to replicate again with the same info.
                    if error_paths:
                        failure_devs_info.update([
                            (failure_dev['replication_ip'],
                             failure_dev['device'])
                            for failure_dev in job['nodes']
                        ])
                else:
                    self.delete_partition(job['path'])
                    handoff_partition_deleted = True
            elif not suffixes:
                self.delete_partition(job['path'])
                handoff_partition_deleted = True
        except (Exception, Timeout):
            self.logger.exception(_("Error syncing handoff partition"))
            self._add_failure_stats(failure_devs_info)
        finally:
            target_devs_info = set([(target_dev['replication_ip'],
                                     target_dev['device'])
                                    for target_dev in job['nodes']])
            self.stats['success'] += len(target_devs_info - failure_devs_info)
            if not handoff_partition_deleted:
                self.handoffs_remaining += 1
            self.partition_times.append(time.time() - begin)
            self.logger.timing_since('partition.delete.timing', begin)
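The handoff-deletion decision buried in the middle of this method can be factored into a pure function; a sketch under the same rules (should_delete_handoff is a hypothetical helper, not part of Swift):

def should_delete_handoff(responses, num_nodes, handoff_delete):
    if handoff_delete:
        # Delete once at least `handoff_delete` syncs succeeded.
        return sum(1 for resp in responses if resp) >= handoff_delete
    # Otherwise require a successful sync to every node.
    return len(responses) == num_nodes and all(responses)

assert should_delete_handoff([True, False, True], 3, 2)
assert not should_delete_handoff([True, False, True], 3, None)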
Code example #47
def _scan_directory(directory, output_format, ssd):
    """Scan a directory for log files and write the final model to stdout."""
    if output_format == _OUTPUT_FORMAT_LINES:
        print('directory =', directory)

    model_spec_filename = os.path.join(directory, 'model_spec.json')
    if not tf.io.gfile.exists(model_spec_filename):
        print('file {} not found; skipping'.format(model_spec_filename))
        if output_format == _OUTPUT_FORMAT_LINES:
            print()
        return

    with tf.io.gfile.GFile(model_spec_filename, 'r') as handle:
        model_spec = schema_io.deserialize(handle.read())

    paths = []
    oneofs = dict()

    def populate_oneofs(path, oneof):
        paths.append(path)
        oneofs[path] = oneof

    schema.map_oneofs_with_paths(populate_oneofs, model_spec)

    all_path_logits = analyze_mobile_search_lib.read_path_logits(directory)
    if not all_path_logits:
        print(
            'event data missing from directory {}; skipping'.format(directory))
        if output_format == _OUTPUT_FORMAT_LINES:
            print()
        return

    global_step = max(all_path_logits)
    if output_format == _OUTPUT_FORMAT_LINES:
        print('global_step = {:d}'.format(global_step))

    all_path_logit_keys = six.viewkeys(all_path_logits[global_step])
    oneof_keys = six.viewkeys(oneofs)
    if all_path_logit_keys != oneof_keys:
        raise ValueError(
            'OneOf key mismatch. Present in event files but not in model_spec: {}. '
            'Present in model_spec but not in event files: {}'.format(
                all_path_logit_keys - oneof_keys,
                oneof_keys - all_path_logit_keys))

    indices = []
    for path in paths:
        index = np.argmax(all_path_logits[global_step][path])
        indices.append(index)

    indices_str = ':'.join(map(str, indices))
    if output_format == _OUTPUT_FORMAT_LINES:
        print('indices = {:s}'.format(indices_str))

    cost_model_time = mobile_cost_model.estimate_cost(indices, ssd)
    if output_format == _OUTPUT_FORMAT_LINES:
        print('cost_model = {:f}'.format(cost_model_time))

    if output_format == _OUTPUT_FORMAT_LINES:
        print()
    elif output_format == _OUTPUT_FORMAT_CSV:
        fields = [indices_str, global_step, directory, cost_model_time]
        print(','.join(map(str, fields)))
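For reference, the per-path selection that produces indices_str boils down to one argmax per OneOf path. A standalone sketch, with hypothetical paths and logits standing in for real event-file data:

import numpy as np

# Hypothetical logits keyed by OneOf path (stand-ins for real event data).
all_path_logits = {
    'blocks/0/op': np.array([0.1, 2.3, -0.5]),    # argmax -> 1
    'blocks/0/filters': np.array([-1.0, 0.2]),    # argmax -> 0
}
paths = sorted(all_path_logits)
indices = [int(np.argmax(all_path_logits[path])) for path in paths]
print(':'.join(map(str, indices)))  # prints "0:1"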
コード例 #48
def visualize_boxes_and_labels_on_image_array(
        image,
        boxes,
        classes,
        scores,
        category_index,
        instance_masks=None,
        instance_boundaries=None,
        keypoints=None,
        keypoint_scores=None,
        keypoint_edges=None,
        track_ids=None,
        use_normalized_coordinates=False,
        max_boxes_to_draw=20,
        min_score_thresh=.5,
        agnostic_mode=False,
        line_thickness=4,
        mask_alpha=.4,
        groundtruth_box_visualization_color='black',
        skip_boxes=False,
        skip_scores=False,
        skip_labels=False,
        skip_track_ids=False):
    """Overlay labeled boxes on an image with formatted scores and label names.

  This function groups boxes that correspond to the same location
  and creates a display string for each detection and overlays these
  on the image. Note that this function modifies the image in place, and returns
  that same image.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None.  If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plot all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    instance_masks: a uint8 numpy array of shape [N, image_height, image_width],
      can be None.
    instance_boundaries: a numpy array of shape [N, image_height, image_width]
      with values ranging between 0 and 1, can be None.
    keypoints: a numpy array of shape [N, num_keypoints, 2], can
      be None.
    keypoint_scores: a numpy array of shape [N, num_keypoints], can be None.
    keypoint_edges: A list of tuples with keypoint indices that specify which
      keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
      edges from keypoint 0 to 1 and from keypoint 2 to 4.
    track_ids: a numpy array of shape [N] with unique track ids. If provided,
      color-coding of boxes will be determined by these ids, and not the class
      indices.
    use_normalized_coordinates: whether boxes is to be interpreted as
      normalized coordinates or not.
    max_boxes_to_draw: maximum number of boxes to visualize.  If None, draw
      all boxes.
    min_score_thresh: minimum score threshold for a box or keypoint to be
      visualized.
    agnostic_mode: boolean (default: False) controlling whether to evaluate in
      class-agnostic mode or not.  This mode will display scores but ignore
      classes.
    line_thickness: integer (default: 4) controlling line width of the boxes.
    mask_alpha: transparency value between 0 and 1 (default: 0.4).
    groundtruth_box_visualization_color: box color for visualizing groundtruth
      boxes
    skip_boxes: whether to skip the drawing of bounding boxes.
    skip_scores: whether to skip score when drawing a single detection
    skip_labels: whether to skip label when drawing a single detection
    skip_track_ids: whether to skip track id when drawing a single detection

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
  """
    # Create a display string (and color) for every box location, group any boxes
    # that correspond to the same location.
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_masks_map = {}
    box_to_instance_boundaries_map = {}
    box_to_keypoints_map = collections.defaultdict(list)
    box_to_keypoint_scores_map = collections.defaultdict(list)
    box_to_track_ids_map = {}
    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]
    for i in range(boxes.shape[0]):
        if max_boxes_to_draw == len(box_to_color_map):
            break
        if scores is None or scores[i] > min_score_thresh:
            box = tuple(boxes[i].tolist())
            if instance_masks is not None:
                box_to_instance_masks_map[box] = instance_masks[i]
            if instance_boundaries is not None:
                box_to_instance_boundaries_map[box] = instance_boundaries[i]
            if keypoints is not None:
                box_to_keypoints_map[box].extend(keypoints[i])
            if keypoint_scores is not None:
                box_to_keypoint_scores_map[box].extend(keypoint_scores[i])
            if track_ids is not None:
                box_to_track_ids_map[box] = track_ids[i]
            if scores is None:
                box_to_color_map[box] = groundtruth_box_visualization_color
            else:
                display_str = ''
                if not skip_labels:
                    if not agnostic_mode:
                        if classes[i] in six.viewkeys(category_index):
                            class_name = category_index[classes[i]]['name']
                        else:
                            class_name = 'N/A'
                        display_str = str(class_name)
                if not skip_scores:
                    if not display_str:
                        display_str = '{}%'.format(round(100 * scores[i]))
                    else:
                        display_str = '{}: {}%'.format(display_str,
                                                       round(100 * scores[i]))
                if not skip_track_ids and track_ids is not None:
                    if not display_str:
                        display_str = 'ID {}'.format(track_ids[i])
                    else:
                        display_str = '{}: ID {}'.format(
                            display_str, track_ids[i])
                box_to_display_str_map[box].append(display_str)
                if agnostic_mode:
                    box_to_color_map[box] = 'DarkOrange'
                elif track_ids is not None:
                    prime_multiplier = _get_multiplier_for_color_randomness()
                    box_to_color_map[box] = STANDARD_COLORS[
                        (prime_multiplier * track_ids[i]) %
                        len(STANDARD_COLORS)]
                else:
                    box_to_color_map[box] = STANDARD_COLORS[
                        classes[i] % len(STANDARD_COLORS)]

    # Draw all boxes onto image.
    for box, color in box_to_color_map.items():
        ymin, xmin, ymax, xmax = box
        if instance_masks is not None:
            draw_mask_on_image_array(image,
                                     box_to_instance_masks_map[box],
                                     color=color,
                                     alpha=mask_alpha)
        if instance_boundaries is not None:
            draw_mask_on_image_array(image,
                                     box_to_instance_boundaries_map[box],
                                     color='red',
                                     alpha=1.0)
        draw_bounding_box_on_image_array(
            image,
            ymin,
            xmin,
            ymax,
            xmax,
            color=color,
            thickness=0 if skip_boxes else line_thickness,
            display_str_list=box_to_display_str_map[box],
            use_normalized_coordinates=use_normalized_coordinates)
        if keypoints is not None:
            keypoint_scores_for_box = None
            if box_to_keypoint_scores_map:
                keypoint_scores_for_box = box_to_keypoint_scores_map[box]
            draw_keypoints_on_image_array(
                image,
                box_to_keypoints_map[box],
                keypoint_scores_for_box,
                min_score_thresh=min_score_thresh,
                color=color,
                radius=line_thickness / 2,
                use_normalized_coordinates=use_normalized_coordinates,
                keypoint_edges=keypoint_edges,
                keypoint_edge_color=color,
                keypoint_edge_width=line_thickness // 2)

    return image
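A minimal usage sketch for the function above, assuming the TF Object Detection API is installed and exposes it from object_detection.utils.visualization_utils (its usual home); the image, boxes, and category_index below are dummies:

import numpy as np
from object_detection.utils import visualization_utils as viz_utils

image = np.zeros((480, 640, 3), dtype=np.uint8)
boxes = np.array([[0.1, 0.1, 0.5, 0.5]])           # normalized [ymin, xmin, ymax, xmax]
classes = np.array([1])                            # 1-based ids matching category_index
scores = np.array([0.92])
category_index = {1: {'id': 1, 'name': 'person'}}

viz_utils.visualize_boxes_and_labels_on_image_array(
    image, boxes, classes, scores, category_index,
    use_normalized_coordinates=True)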
コード例 #49
def visualize_boxes_and_labels_on_image_array(
        image,
        boxes,
        classes,
        scores,
        category_index,
        instance_boundaries=None,
        track_ids=None,
        use_normalized_coordinates=False,
        max_boxes_to_draw=20,
        min_score_thresh=.5,
        agnostic_mode=False,
        line_thickness=4,
        groundtruth_box_visualization_color='black',
        skip_scores=False,
        skip_labels=False,
        skip_track_ids=False):
    """Overlay labeled boxes on an image with formatted scores and label names.

    This function groups boxes that correspond to the same location
    and creates a display string for each detection and overlays these
    on the image. Note that this function modifies the image in place, and returns
    that same image.

    Args:
      image: uint8 numpy array with shape (img_height, img_width, 3)
      boxes: a numpy array of shape [N, 4]
      classes: a numpy array of shape [N]. Note that class indices are 1-based,
        and match the keys in the label map.
      scores: a numpy array of shape [N] or None.  If scores=None, then
        this function assumes that the boxes to be plotted are groundtruth
        boxes and plots all boxes as black with no classes or scores.
      category_index: a dict containing category dictionaries (each holding
        category index `id` and category name `name`) keyed by category indices.
      instance_boundaries: a numpy array of shape [N, image_height, image_width]
        with values ranging between 0 and 1, can be None.
      track_ids: a numpy array of shape [N] with unique track ids. If provided,
        color-coding of boxes will be determined by these ids, and not the class
        indices.
      use_normalized_coordinates: whether boxes is to be interpreted as
        normalized coordinates or not.
      max_boxes_to_draw: maximum number of boxes to visualize.  If None, draw
        all boxes.
      min_score_thresh: minimum score threshold for a box to be visualized
      agnostic_mode: boolean (default: False) controlling whether to evaluate in
        class-agnostic mode or not.  This mode will display scores but ignore
        classes.
      line_thickness: integer (default: 4) controlling line width of the boxes.
      groundtruth_box_visualization_color: box color for visualizing groundtruth
        boxes
      skip_scores: whether to skip score when drawing a single detection
      skip_labels: whether to skip label when drawing a single detection
      skip_track_ids: whether to skip track id when drawing a single detection

    Returns:
      uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
    """
    # Create a display string (and color) for every box location, group any boxes
    # that correspond to the same location.
    if len(boxes) < 1:
        print("No boxes")
        return image

    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_boundaries_map = {}
    box_to_track_ids_map = {}
    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]
    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        if scores is None or scores[i] > min_score_thresh:
            box = tuple(boxes[i].tolist())
            if instance_boundaries is not None:
                box_to_instance_boundaries_map[box] = instance_boundaries[i]
            if track_ids is not None:
                box_to_track_ids_map[box] = track_ids[i]
            if scores is None:
                box_to_color_map[box] = groundtruth_box_visualization_color
            else:
                display_str = ''
                if not skip_labels:
                    if not agnostic_mode:
                        if classes[i] in six.viewkeys(category_index):
                            class_name = category_index[classes[i]]['name']
                        else:
                            class_name = 'N/A'
                        display_str = str(class_name)
                if not skip_scores:
                    if not display_str:
                        display_str = '{}%'.format(int(100 * scores[i]))
                    else:
                        display_str = '{}: {}%'.format(display_str,
                                                       int(100 * scores[i]))
                if not skip_track_ids and track_ids is not None:
                    if not display_str:
                        display_str = 'ID {}'.format(track_ids[i])
                    else:
                        display_str = '{}: ID {}'.format(
                            display_str, track_ids[i])
                box_to_display_str_map[box].append(display_str)
                if agnostic_mode:
                    box_to_color_map[box] = 'DarkOrange'
                elif track_ids is not None:
                    prime_multiplier = _get_multiplier_for_color_randomness()
                    box_to_color_map[box] = STANDARD_COLORS[
                        (prime_multiplier * track_ids[i]) %
                        len(STANDARD_COLORS)]
                else:
                    box_to_color_map[box] = STANDARD_COLORS[
                        classes[i] % len(STANDARD_COLORS)]

    # Draw all boxes onto image.
    for box, color in box_to_color_map.items():
        # The drawing helper expects corner-format boxes (ymin, xmin, ymax,
        # xmax), but the boxes passed to this variant use the WOD format
        # (center_x, center_y, width, height), so convert back to corners
        # before drawing.
        center_x, center_y, width, height = box
        xmin = center_x - width / 2
        xmax = center_x + width / 2
        ymin = center_y - height / 2
        ymax = center_y + height / 2
        draw_bounding_box_on_image_array(
            image,
            ymin,
            xmin,
            ymax,
            xmax,
            color=color,
            thickness=line_thickness,
            display_str_list=box_to_display_str_map[box],
            use_normalized_coordinates=use_normalized_coordinates)
    return image
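The conversion in the drawing loop can be checked in isolation; a small standalone sketch of (center_x, center_y, width, height) to (ymin, xmin, ymax, xmax):

def center_to_corners(box):
    """Convert (center_x, center_y, width, height) to (ymin, xmin, ymax, xmax)."""
    center_x, center_y, width, height = box
    return (center_y - height / 2, center_x - width / 2,
            center_y + height / 2, center_x + width / 2)

print(center_to_corners((0.5, 0.5, 0.2, 0.4)))  # (0.3, 0.4, 0.7, 0.6)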
コード例 #50
    def _promote_kwargs(cls, kwargs, optional=None, read_only=None):
        if optional is None:
            optional = set()
        if not isinstance(optional, set):
            raise TypeError(
                "Optional kwargs must be given as a set, not {}".format(optional)
            )
        if read_only is None:
            read_only = set()
        if not isinstance(read_only, set):
            raise TypeError(
                "Read only kwargs must be given as a set, not {}".format(read_only)
            )
        class_name = cls.__name__
        try:
            type_params = cls._type_params[0]
        except TypeError:
            raise TypeError(
                "Cannot instantiate a generic {}; the item types must be specified".format(
                    class_name
                )
            )

        missing_args = (
            six.viewkeys(type_params) - six.viewkeys(kwargs) - optional - read_only
        )

        if len(missing_args) > 0:
            raise ProxyTypeError(
                "Missing required keyword arguments to {}: {}".format(
                    class_name, ", ".join(six.moves.map(repr, missing_args))
                )
            )

        provided_read_only_args = six.viewkeys(kwargs) & read_only
        if len(provided_read_only_args) > 0:
            raise ProxyTypeError(
                "Read only keyword argument to {}: {}".format(
                    class_name, ", ".join(six.moves.map(repr, provided_read_only_args))
                )
            )

        promoted_kwargs = {}
        for field_name, val in six.iteritems(kwargs):
            try:
                field_cls = type_params[field_name]
            except KeyError:
                raise ProxyTypeError(
                    "{} has no field {!r}".format(class_name, field_name)
                )

            if (val is None or isinstance(val, NoneType)) and field_name in optional:
                continue
            try:
                promoted_val = field_cls._promote(val)
            except ProxyTypeError as e:
                raise ProxyTypeError(
                    "In field {!r} of {}, expected {}, but got {}: {}".format(
                        field_name, class_name, field_cls, type(val), val
                    ),
                    e,
                )
            promoted_kwargs[field_name] = promoted_val

        return promoted_kwargs
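The missing-argument and read-only checks above are dict-view set algebra (on Python 3, six.viewkeys(d) is just d.keys()). A standalone sketch with hypothetical field names:

type_params = {'a': int, 'b': str, 'c': float}   # hypothetical field types
kwargs = {'a': 1, 'c': 2.0, 'version': 7}
optional, read_only = {'b'}, {'version'}

missing = type_params.keys() - kwargs.keys() - optional - read_only
provided_read_only = kwargs.keys() & read_only
print(missing, provided_read_only)  # set() {'version'}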
コード例 #51
ファイル: manifest.py プロジェクト: veeg/servo
    def update(self, tree):
        new_data = defaultdict(dict)
        new_hashes = {}

        reftest_nodes = []
        old_files = defaultdict(
            set, {k: set(viewkeys(v))
                  for k, v in iteritems(self._data)})

        changed = False
        reftest_changes = False

        for source_file in tree:
            rel_path = source_file.rel_path
            file_hash = source_file.hash

            is_new = rel_path not in self._path_hash
            hash_changed = False

            if not is_new:
                old_hash, old_type = self._path_hash[rel_path]
                old_files[old_type].remove(rel_path)
                if old_hash != file_hash:
                    new_type, manifest_items = source_file.manifest_items()
                    hash_changed = True
                else:
                    new_type, manifest_items = old_type, self._data[old_type][
                        rel_path]
                if old_type == "reftest" and new_type != old_type:
                    reftest_changes = True
            else:
                new_type, manifest_items = source_file.manifest_items()

            if new_type in ("reftest", "reftest_node"):
                reftest_nodes.extend(manifest_items)
                if is_new or hash_changed:
                    reftest_changes = True
            elif new_type:
                new_data[new_type][rel_path] = set(manifest_items)

            new_hashes[rel_path] = (file_hash, new_type)

            if is_new or hash_changed:
                changed = True

        if reftest_changes or old_files["reftest"] or old_files["reftest_node"]:
            reftests, reftest_nodes, changed_hashes = self._compute_reftests(
                reftest_nodes)
            new_data["reftest"] = reftests
            new_data["reftest_node"] = reftest_nodes
            new_hashes.update(changed_hashes)
        else:
            new_data["reftest"] = self._data["reftest"]
            new_data["reftest_node"] = self._data["reftest_node"]

        if any(itervalues(old_files)):
            changed = True

        self._data = new_data
        self._path_hash = new_hashes

        return changed
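The old_files bookkeeping above snapshots every known path per item type, then removes each path seen in the tree so only stale entries remain. A standalone sketch with made-up manifest data:

from collections import defaultdict

# Hypothetical manifest data: {item type: {rel_path: items}}.
data = {'testharness': {'a.html': 1, 'b.html': 2}, 'reftest': {'r.html': 3}}
old_files = defaultdict(set, {k: set(v.keys()) for k, v in data.items()})

old_files['testharness'].remove('a.html')        # a.html seen again in the tree
print({k: v for k, v in old_files.items() if v})
# {'testharness': {'b.html'}, 'reftest': {'r.html'}} -- stale entries remain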
コード例 #52
ファイル: function.py プロジェクト: rxy007/pytype
    def has_param_annotations(self):
        return bool(six.viewkeys(self.annotations) - {"return"})
コード例 #53
    def testSettings(self):
        users = self.users

        # Only admins should be able to get or set settings
        for method in ('GET', 'PUT', 'DELETE'):
            resp = self.request(path='/system/setting',
                                method=method,
                                params={
                                    'key': 'foo',
                                    'value': 'bar'
                                },
                                user=users[1])
            self.assertStatus(resp, 403)

        # Only valid setting keys should be allowed
        obj = ['oauth', 'geospatial', '_invalid_']
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'key': 'foo',
                                'value': json.dumps(obj)
                            },
                            user=users[0])
        self.assertStatus(resp, 400)
        self.assertEqual(resp.json['field'], 'key')

        # Only a valid JSON list is permitted
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={'list': json.dumps('not_a_list')},
                            user=users[0])
        self.assertStatus(resp, 400)

        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={'list': json.dumps('not_a_list')},
                            user=users[0])
        self.assertStatus(resp, 400)

        # Set an invalid setting value, should fail
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'key': SettingKey.PLUGINS_ENABLED,
                                'value': json.dumps(obj)
                            },
                            user=users[0])
        self.assertStatus(resp, 400)
        self.assertEqual(resp.json['message'],
                         'Required plugin _invalid_ does not exist.')

        # Set a valid value
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'key': SettingKey.PLUGINS_ENABLED,
                                'value': json.dumps(['geospatial', 'oauth'])
                            },
                            user=users[0])
        self.assertStatusOk(resp)

        # We should now be able to retrieve it
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={'key': SettingKey.PLUGINS_ENABLED},
                            user=users[0])
        self.assertStatusOk(resp)
        self.assertEqual(set(resp.json), set(['geospatial', 'oauth']))

        # We should now clear the setting
        resp = self.request(path='/system/setting',
                            method='DELETE',
                            params={'key': SettingKey.PLUGINS_ENABLED},
                            user=users[0])
        self.assertStatusOk(resp)

        # Setting should now be back to the default, []
        setting = Setting().get(SettingKey.PLUGINS_ENABLED)
        self.assertEqual(setting, [])

        # We should be able to ask for a different default
        setting = Setting().get(SettingKey.PLUGINS_ENABLED, default=None)
        self.assertEqual(setting, None)

        # We should also be able to put several setting using a JSON list
        resp = self.request(path='/system/setting',
                            method='PUT',
                            params={
                                'list':
                                json.dumps([
                                    {
                                        'key': SettingKey.PLUGINS_ENABLED,
                                        'value': json.dumps(())
                                    },
                                    {
                                        'key': SettingKey.COOKIE_LIFETIME,
                                        'value': None
                                    },
                                ])
                            },
                            user=users[0])
        self.assertStatusOk(resp)

        # We can get a list as well
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={
                                'list':
                                json.dumps([
                                    SettingKey.PLUGINS_ENABLED,
                                    SettingKey.COOKIE_LIFETIME,
                                ])
                            },
                            user=users[0])
        self.assertStatusOk(resp)
        self.assertEqual(resp.json[SettingKey.PLUGINS_ENABLED], [])

        # We can get the default values, or ask for no value if the current
        # value is taken from the default
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={
                                'key': SettingKey.PLUGINS_ENABLED,
                                'default': 'default'
                            },
                            user=users[0])
        self.assertStatusOk(resp)
        self.assertEqual(resp.json, [])

        resp = self.request(path='/system/setting',
                            method='GET',
                            params={
                                'key': SettingKey.COOKIE_LIFETIME,
                                'default': 'none'
                            },
                            user=users[0])
        self.assertStatusOk(resp)
        self.assertEqual(resp.json, None)

        # But we have to ask for a sensible value in the default parameter
        resp = self.request(path='/system/setting',
                            method='GET',
                            params={
                                'key': SettingKey.COOKIE_LIFETIME,
                                'default': 'bad_value'
                            },
                            user=users[0])
        self.assertStatus(resp, 400)

        # Try to set each key in turn to test the validation.  First test with
        # an invalid value, then test with the default value.  If the value
        # 'bad' won't trigger a validation error, the key should be present in
        # the badValues table.
        badValues = {
            SettingKey.BRAND_NAME: '',
            SettingKey.BANNER_COLOR: '',
            SettingKey.EMAIL_FROM_ADDRESS: '',
            SettingKey.CONTACT_EMAIL_ADDRESS: '',
            SettingKey.EMAIL_HOST: {},
            SettingKey.SMTP_HOST: '',
            SettingKey.CORS_ALLOW_ORIGIN: {},
            SettingKey.CORS_ALLOW_METHODS: {},
            SettingKey.CORS_ALLOW_HEADERS: {},
        }
        allKeys = dict.fromkeys(six.viewkeys(SettingDefault.defaults))
        allKeys.update(badValues)
        for key in allKeys:
            resp = self.request(path='/system/setting',
                                method='PUT',
                                params={
                                    'key': key,
                                    'value': badValues.get(key, 'bad')
                                },
                                user=users[0])
            self.assertStatus(resp, 400)
            resp = self.request(path='/system/setting',
                                method='PUT',
                                params={
                                    'key':
                                    key,
                                    'value':
                                    json.dumps(
                                        SettingDefault.defaults.get(key, ''))
                                },
                                user=users[0])
            self.assertStatusOk(resp)
            resp = self.request(
                path='/system/setting',
                method='PUT',
                params={'list': json.dumps([{
                    'key': key,
                    'value': None
                }])},
                user=users[0])
            self.assertStatusOk(resp)
            resp = self.request(path='/system/setting',
                                method='GET',
                                params={
                                    'key': key,
                                    'default': 'default'
                                },
                                user=users[0])
            self.assertStatusOk(resp)
コード例 #54
def validate_schema_and_query_ast(schema, query_ast):
    """Validate the supplied GraphQL schema and query_ast.

    This method wraps around graphql-core's validation to enforce a stricter requirement of the
    schema -- all directives supported by the compiler must be declared by the schema, regardless of
    whether each directive is used in the query or not.

    Args:
        schema: GraphQL schema object, created using the GraphQL library
        query_ast: abstract syntax tree representation of a GraphQL query

    Returns:
        list containing schema and/or query validation errors
    """
    core_graphql_errors = validate(schema, query_ast)

    # The following directives appear in the core-graphql library, but are not supported by the
    # GraphQL compiler.
    unsupported_default_directives = frozenset([
        frozenset([
            "include",
            frozenset(["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"]),
            frozenset(["if"]),
        ]),
        frozenset([
            "skip",
            frozenset(["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"]),
            frozenset(["if"]),
        ]),
    ])

    # This directive is supported and ignored by the compiler, since it is meant as an indication
    # to the user that a field should not be used.
    supported_default_directive = frozenset([
        frozenset([
            "deprecated",
            frozenset(["FIELD_DEFINITION", "ENUM_VALUE"]),
            frozenset(["reason"])
        ])
    ])

    # Directives expected by the graphql compiler.
    expected_directives = {
        frozenset([
            directive.name,
            frozenset(directive.locations),
            frozenset(six.viewkeys(directive.args)),
        ])
        for directive in DIRECTIVES
    }

    # Directives provided in the parsed graphql schema.
    actual_directives = {
        frozenset([
            directive.name,
            frozenset(directive.locations),
            frozenset(six.viewkeys(directive.args)),
        ])
        for directive in schema.get_directives()
    }

    # Directives missing from the actual directives provided.
    missing_directives = expected_directives - actual_directives
    if missing_directives:
        missing_message = (u"The following directives were missing from the "
                           u"provided schema: {}".format(missing_directives))
        core_graphql_errors.append(missing_message)

    # Directives that are not specified by the core graphql library. Note that Graphql-core
    # automatically injects default directives into the schema, regardless of whether
    # the schema supports said directives. Hence, while the directives contained in
    # unsupported_default_directives are incompatible with the graphql-compiler, we allow them to
    # be present in the parsed schema string.
    extra_directives = (actual_directives - expected_directives -
                        unsupported_default_directives -
                        supported_default_directive)
    if extra_directives:
        extra_message = (
            u"The following directives were supplied in the given schema, but "
            u"are not supported by the GraphQL compiler: {}".format(
                extra_directives))
        core_graphql_errors.append(extra_message)

    return core_graphql_errors
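Each directive is reduced to a hashable signature (name, locations, argument names) so plain set difference can report mismatches. A standalone sketch with a made-up directive:

# Hypothetical directive data; the real entries come from DIRECTIVES and
# schema.get_directives().
def directive_signature(name, locations, arg_names):
    return frozenset([name, frozenset(locations), frozenset(arg_names)])

expected = {directive_signature('filter', ['FIELD'], ['op_name', 'value'])}
actual = {directive_signature('filter', ['FIELD'], ['op_name', 'value'])}
print(expected - actual)  # set(): no directives missing from the schema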
コード例 #55
def independent_sample(structure,
                       increase_ops_probability=None,
                       increase_filters_probability=None,
                       hierarchical=True,
                       name=None,
                       temperature=1.0):
    """Generate a search space specification for an RL controller model.

  Each OneOf value is sampled independently of every other; hence the name.

  Args:
    structure: Nested data structure containing OneOf objects to search over.
    increase_ops_probability: Scalar float Tensor or None. If not None, we will
        randomly enable all possible operations instead of just the selected
        operations with this probability.
    increase_filters_probability: Scalar float Tensor or None. If not None, we
        will randomly use the largest possible filter sizes with this
        probability.
    hierarchical: Boolean. If true, the values of the outputs `sample_log_prob`
        and `entropy` will only take into account subgraphs that are enabled at
        the current training step.
    name: Optional name for the newly created TensorFlow scope.
    temperature: Positive scalar controlling the temperature to use when
        sampling from the RL controller.

  Returns:
    A tuple (new_structure, dist_info) where `new_structure` is a copy
    of `structure` annotated with mask tensors, and `dist_info` is a
    dictionary containing information about the sampling distribution
    which contains the following keys:
      - entropy: Scalar float Tensor, entropy of the current probability
            distribution.
      - logits_by_path: OrderedDict of rank-1 Tensors, sample-independent logits
            for each OneOf in `structure`. Names are derived from OneOf paths.
      - logits_by_tag: OrderedDict of rank-1 Tensors, sample-independent logits
            for each OneOf in `structure`. Names are derived from OneOf tags.
      - sample_log_prob: Scalar float Tensor, log-probability of the current
            sample associated with `new_structure`.
  """
    with tf.variable_scope(name, 'independent_sample'):
        temperature = tf.convert_to_tensor(temperature, tf.float32)
        dist_info = {
            'entropy': tf.constant(0, tf.float32),
            'logits_by_path': collections.OrderedDict(),
            'logits_by_tag': collections.OrderedDict(),
            'sample_log_prob': tf.constant(0, tf.float32),
        }
        tag_counters = collections.Counter()

        entropies = dict()
        log_probs = dict()
        is_active = dict()

        def visit(tuple_path, oneof):
            """Visit a OneOf node in `structure`."""
            string_path = '/'.join(map(str, tuple_path))
            num_choices = len(oneof.choices)

            logits = tf.get_variable(name='logits/' + string_path,
                                     initializer=tf.initializers.zeros(),
                                     shape=[num_choices],
                                     dtype=tf.float32)
            logits = logits / temperature

            tag_name = '{:s}_{:d}'.format(oneof.tag, tag_counters[oneof.tag])
            tag_counters[oneof.tag] += 1

            dist_info['logits_by_path'][string_path] = logits
            dist_info['logits_by_tag'][tag_name] = logits

            dist = tfp.distributions.OneHotCategorical(logits=logits,
                                                       dtype=tf.float32)
            entropies[tuple_path] = dist.entropy()

            sample_mask = dist.sample()
            sample_log_prob = dist.log_prob(sample_mask)
            if oneof.tag == basic_specs.OP_TAG:
                sample_mask, sample_log_prob = _replace_sample_with_probability(
                    sample_mask, sample_log_prob, increase_ops_probability,
                    tf.constant([1.0 / num_choices] * num_choices, tf.float32))
            elif oneof.tag == basic_specs.FILTERS_TAG:
                # NOTE: While this branch was originally written for integer
                # filter sizes, np.argmax() will also work with any choice
                # type that supports "less than" and "greater than"
                # comparisons.
                sample_mask, sample_log_prob = _replace_sample_with_probability(
                    sample_mask, sample_log_prob, increase_filters_probability,
                    tf.one_hot(np.argmax(oneof.choices), len(oneof.choices)))

            log_probs[tuple_path] = sample_log_prob
            for i in range(len(oneof.choices)):
                tuple_subpath = tuple_path + ('choices', i)
                is_active[tuple_subpath] = tf.greater(tf.abs(sample_mask[i]),
                                                      1e-6)

            return schema.OneOf(choices=oneof.choices,
                                tag=oneof.tag,
                                mask=sample_mask)

        new_structure = schema.map_oneofs_with_tuple_paths(visit, structure)

        assert six.viewkeys(entropies) == six.viewkeys(log_probs)
        for path in entropies:
            path_is_active = tf.constant(True)
            if hierarchical:
                for i in range(len(path) + 1):
                    if path[:i] in is_active:
                        path_is_active = tf.logical_and(
                            path_is_active, is_active[path[:i]])

            path_is_active = tf.cast(path_is_active, tf.float32)
            dist_info['entropy'] += entropies[path] * path_is_active
            dist_info['sample_log_prob'] += log_probs[path] * path_is_active

        return (new_structure, dist_info)
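The hierarchical gating at the end sums per-path entropies and log-probs weighted by whether every ancestor choice is active. A plain-Python sketch with hypothetical paths (the real code uses float32 Tensors):

# Hypothetical per-path values keyed by tuple path.
entropies = {('blocks', 0): 0.6,
             ('blocks', 0, 'choices', 1, 'op'): 0.9}
is_active = {('blocks', 0, 'choices', 1): True}

total_entropy = 0.0
for path, entropy in entropies.items():
    # A path contributes only if every ancestor prefix is active; prefixes
    # without an entry (e.g. the root) are treated as active.
    active = all(is_active.get(path[:i], True) for i in range(len(path) + 1))
    total_entropy += entropy if active else 0.0
print(total_entropy)  # 1.5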
コード例 #56
    def countries(self):
        """A set-like object of the country codes supplied by this reader.
        """
        return viewkeys(self._readers)
コード例 #57
ファイル: module.py プロジェクト: pombredanne/progressivis
    def _filter_kwds(kwds, function_or_method):
        argspec = getfullargspec(function_or_method)
        keys_ = argspec.args[len(argspec.args) - (
            0 if argspec.defaults is None else len(argspec.defaults)):]
        filtered_kwds = {k: kwds[k] for k in six.viewkeys(kwds) & keys_}
        return filtered_kwds
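A usage sketch of the same filtering logic, written against inspect.getfullargspec directly (which is what the helper resolves to on Python 3); connect and its kwargs are hypothetical:

from inspect import getfullargspec

def connect(host, port=80, timeout=5.0):
    return host, port, timeout

spec = getfullargspec(connect)
# Parameters that carry defaults, mirroring keys_ above: ['port', 'timeout'].
keyword_args = spec.args[len(spec.args) - len(spec.defaults or ()):]
kwds = {'port': 8080, 'retries': 3}
print({k: kwds[k] for k in kwds.keys() & set(keyword_args)})  # {'port': 8080}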
コード例 #58
def visualize_boxes_and_labels_on_image_array(
        file_name,
        image,
        boxes,
        classes,
        scores,
        category_index,
        instance_masks=None,
        instance_boundaries=None,
        keypoints=None,
        track_ids=None,
        use_normalized_coordinates=False,
        max_boxes_to_draw=20,
        min_score_thresh=.5,
        agnostic_mode=False,
        line_thickness=4,
        groundtruth_box_visualization_color='black',
        skip_scores=False,
        skip_labels=False,
        skip_track_ids=False):
    """Overlay labeled boxes on an image with formatted scores and label names.

    This function groups boxes that correspond to the same location
    and creates a display string for each detection and overlays these
    on the image. Note that this function modifies the image in place, and returns
    that same image.

    Args:
      file_name: name of the source image file, passed through to the
        generated XML annotation.
      image: uint8 numpy array with shape (img_height, img_width, 3)
      boxes: a numpy array of shape [N, 4]
      classes: a numpy array of shape [N]. Note that class indices are 1-based,
        and match the keys in the label map.
      scores: a numpy array of shape [N] or None.  If scores=None, then
        this function assumes that the boxes to be plotted are groundtruth
        boxes and plots all boxes as black with no classes or scores.
      category_index: a dict containing category dictionaries (each holding
        category index `id` and category name `name`) keyed by category indices.
      instance_masks: a numpy array of shape [N, image_height, image_width] with
        values ranging between 0 and 1, can be None.
      instance_boundaries: a numpy array of shape [N, image_height, image_width]
        with values ranging between 0 and 1, can be None.
      keypoints: a numpy array of shape [N, num_keypoints, 2], can
        be None
      track_ids: a numpy array of shape [N] with unique track ids. If provided,
        color-coding of boxes will be determined by these ids, and not the class
        indices.
      use_normalized_coordinates: whether boxes is to be interpreted as
        normalized coordinates or not.
      max_boxes_to_draw: maximum number of boxes to visualize.  If None, draw
        all boxes.
      min_score_thresh: minimum score threshold for a box to be visualized
      agnostic_mode: boolean (default: False) controlling whether to evaluate in
        class-agnostic mode or not.  This mode will display scores but ignore
        classes.
      line_thickness: integer (default: 4) controlling line width of the boxes.
      groundtruth_box_visualization_color: box color for visualizing groundtruth
        boxes
      skip_scores: whether to skip score when drawing a single detection
      skip_labels: whether to skip label when drawing a single detection
      skip_track_ids: whether to skip track id when drawing a single detection

    Returns:
      uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
    """
    # Create a display string (and color) for every box location, group any boxes
    # that correspond to the same location.
    new_xml = False
    classes_names = {}

    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_masks_map = {}
    box_to_instance_boundaries_map = {}
    box_to_keypoints_map = collections.defaultdict(list)
    box_to_track_ids_map = {}
    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]
    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        if scores is None or scores[i] > min_score_thresh:
            box = tuple(boxes[i].tolist())
            if instance_masks is not None:
                box_to_instance_masks_map[box] = instance_masks[i]
            if instance_boundaries is not None:
                box_to_instance_boundaries_map[box] = instance_boundaries[i]
            if keypoints is not None:
                box_to_keypoints_map[box].extend(keypoints[i])
            if track_ids is not None:
                box_to_track_ids_map[box] = track_ids[i]
            if scores is None:
                box_to_color_map[box] = groundtruth_box_visualization_color
            else:
                display_str = ''
                if not skip_labels:
                    if not agnostic_mode:
                        if classes[i] in six.viewkeys(category_index):
                            class_name = category_index[classes[i]]['name']
                            classes_names[i] = str(class_name)

                            new_xml = True
                        else:
                            class_name = 'N/A'
                        display_str = str(class_name)
                if not skip_scores:
                    if not display_str:
                        display_str = '{}%'.format(int(100 * scores[i]))
                    else:
                        display_str = '{}: {}%'.format(display_str,
                                                       int(100 * scores[i]))
                if not skip_track_ids and track_ids is not None:
                    if not display_str:
                        display_str = 'ID {}'.format(track_ids[i])
                    else:
                        display_str = '{}: ID {}'.format(
                            display_str, track_ids[i])
                box_to_display_str_map[box].append(display_str)
                if agnostic_mode:
                    box_to_color_map[box] = 'DarkOrange'
                elif track_ids is not None:
                    prime_multiplier = _get_multiplier_for_color_randomness()
                    box_to_color_map[box] = STANDARD_COLORS[
                        (prime_multiplier * track_ids[i]) %
                        len(STANDARD_COLORS)]
                else:
                    box_to_color_map[box] = STANDARD_COLORS[
                        classes[i] % len(STANDARD_COLORS)]

    array_position = []
    im_height, im_width, _channels = image.shape

    # Draw all boxes onto image.
    for box, color in box_to_color_map.items():
        ymin, xmin, ymax, xmax = box
        dict_position = {'xmin': 0, 'xmax': 0, 'ymin': 0, 'ymax': 0}

        dict_position['ymin'] = ymin * im_height
        dict_position['xmin'] = xmin * im_width
        dict_position['ymax'] = ymax * im_height
        dict_position['xmax'] = xmax * im_width

        array_position.append(dict_position)

        if instance_masks is not None:
            draw_mask_on_image_array(image,
                                     box_to_instance_masks_map[box],
                                     color=color)
        if instance_boundaries is not None:
            draw_mask_on_image_array(image,
                                     box_to_instance_boundaries_map[box],
                                     color='red',
                                     alpha=1.0)
        draw_bounding_box_on_image_array(
            image,
            ymin,
            xmin,
            ymax,
            xmax,
            color=color,
            thickness=line_thickness,
            display_str_list=box_to_display_str_map[box],
            use_normalized_coordinates=use_normalized_coordinates)
        if keypoints is not None:
            draw_keypoints_on_image_array(
                image,
                box_to_keypoints_map[box],
                color=color,
                radius=line_thickness / 2,
                use_normalized_coordinates=use_normalized_coordinates)

    if new_xml:
        xml = generate_xml.GenerateXml(array_position, im_width, im_height,
                                       classes_names, file_name)
        xml.gerenate_basic_structure()  # sic: method name as spelled in generate_xml

    return image
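The array_position entries are just the normalized corners scaled to pixels; a one-off standalone check:

im_height, im_width = 480, 640
ymin, xmin, ymax, xmax = 0.1, 0.2, 0.5, 0.6   # normalized corners
print({'xmin': xmin * im_width, 'xmax': xmax * im_width,
       'ymin': ymin * im_height, 'ymax': ymax * im_height})
# {'xmin': 128.0, 'xmax': 384.0, 'ymin': 48.0, 'ymax': 240.0}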
コード例 #59
 def _update_config(self):
     self._set_config(six.viewkeys(self._action_handlers))
コード例 #60
    def create_evaluation_lists(self):

        # return reference
        status_OUT = None

        # declare variables
        citation_map = None
        baseline_list = None
        derived_binary_list = None
        derived_raw_list = None
        cutoff_value = None
        publication_id_list = None
        publication_id = None
        publication_dict = None
        data_set_map = None
        data_set_id_list = None
        data_set_id = None
        data_set_found_map = None
        baseline_score = -1
        derived_score = -1

        # get citation_map
        citation_map = self.get_citation_map()

        # init lists
        baseline_list = self.set_baseline_list([])
        derived_binary_list = self.set_derived_binary_list([])
        derived_raw_list = self.set_derived_raw_list([])

        # cutoffs
        cutoff_value = self.get_cutoff()

        # so we can get publication ID list
        publication_id_list = list(six.viewkeys(citation_map))
        publication_id_list.sort()

        # loop over publications, and then data sets within.
        for publication_id in publication_id_list:

            # DEBUG
            if self.debug_flag:
                print("Publication ID: {}".format(publication_id))
            #-- END DEBUG --#

            # get publication map
            publication_dict = citation_map.get(publication_id, None)

            # get the data set map and ID list.
            data_set_map = publication_dict.get(self.JSON_NAME_DATA_SET_MAP,
                                                None)
            data_set_id_list = list(six.viewkeys(data_set_map))
            data_set_id_list.sort()

            # loop over data set ID list.
            for data_set_id in data_set_id_list:

                # DEBUG
                if self.debug_flag:
                    print("==> Data Set ID: {}".format(data_set_id))
                #-- END DEBUG --#

                # get the data_set_found_map
                data_set_found_map = data_set_map.get(data_set_id, None)

                # get the scores.
                baseline_score = data_set_found_map.get(
                    self.RESULT_TYPE_BASELINE, 0.0)
                derived_score = data_set_found_map.get(
                    self.RESULT_TYPE_DERIVED, 0.0)

                # DEBUG
                if self.debug_flag:
                    print("            baseline: {}".format(baseline_score))
                    print("            derived.: {}".format(derived_score))
                #-- END DEBUG --#

                # add them to the lists
                baseline_list.append(baseline_score)
                derived_raw_list.append(derived_score)
                if derived_score > cutoff_value:
                    derived_binary_list.append(1.0)
                else:
                    derived_binary_list.append(0.0)
                #-- END binary value assignment --#

            #-- END loop over data set IDs. --#

        #-- END loop over publication IDs. --#

        return status_OUT
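The cutoff binarization near the end, in isolation (note the strictly-greater comparison):

cutoff_value = 0.5
derived_raw_list = [0.2, 0.7, 0.5]
derived_binary_list = [1.0 if score > cutoff_value else 0.0
                       for score in derived_raw_list]
print(derived_binary_list)  # [0.0, 1.0, 0.0] -- 0.5 is not strictly greater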