コード例 #1
0
ファイル: parameter.py プロジェクト: pcowpert/MOSFiT
 def __init__(self, **kwargs):
     """Initialize a model parameter from its keyword configuration.

     Recognized keyword arguments (all optional):

     * ``fixed`` -- initial value of the fixed flag (default ``False``).
     * ``min_value`` / ``max_value`` -- bounds of the parameter's range.
       If either bound is missing, the parameter is marked fixed. If both
       are given and equal, a ``min_max_same`` warning is printed, the
       range is discarded and the common bound becomes the parameter's
       value (unless ``value`` is given explicitly).
     * ``value`` -- explicit parameter value.
     * ``log`` -- if true, the range bounds are stored as their natural
       logarithms, so both must be strictly positive.
     * ``latex`` -- LaTeX label (defaults to ``self._name``).
     * ``derived_keys`` -- extra derived keys; ``'reference_<name>'`` is
       always appended.

     All keyword arguments are also forwarded to the parent initializer,
     which presumably sets ``self._name`` and ``self._printer`` (both used
     below) -- TODO confirm.

     Raises:
         ValueError: if ``log`` is set and either bound is <= 0.
     """
     super(Parameter, self).__init__(**kwargs)
     self._fixed = kwargs.get('fixed', False)
     self._fixed_by_user = False
     self._max_value = kwargs.get('max_value', None)
     self._min_value = kwargs.get('min_value', None)
     # Read the explicit value before handling a collapsed range, so the
     # collapsed-range value below is a fallback instead of being
     # overwritten afterwards.
     self._value = kwargs.get('value', None)
     if (self._min_value is not None and self._max_value is not None
             and self._min_value == self._max_value):
         self._printer.message('min_max_same', [self._name], warning=True)
         # A zero-width range pins the parameter to that single value.
         if self._value is None:
             self._value = self._min_value
         self._min_value, self._max_value = None, None
     # Without a complete range there is nothing to vary, so the parameter
     # is fixed. This also covers the collapsed range handled above.
     if self._min_value is None or self._max_value is None:
         self._fixed = True
         self._fixed_by_user = True
     self._log = kwargs.get('log', False)
     self._latex = kwargs.get('latex', self._name)
     self._derived_keys = listify(kwargs.get(
         'derived_keys', [])) + ['reference_' + self._name]
     if (self._log and self._min_value is not None
             and self._max_value is not None):
         if self._min_value <= 0.0 or self._max_value <= 0.0:
             raise ValueError(
                 'Parameter with log prior cannot have range values <= 0!')
         # Bounds are kept in log space from here on.
         self._min_value = np.log(self._min_value)
         self._max_value = np.log(self._max_value)
     self._reference_value = None
     self._clipped_warning = False
コード例 #2
0
ファイル: model.py プロジェクト: slowdivePTG/MOSFiT
 def construct_trees(self,
                     d,
                     trees,
                     simple,
                     kinds=None,
                     name='',
                     roots=None,
                     depth=0):
     """Construct call trees for each root.

     Every entry of ``d`` whose ``'kind'`` is in ``kinds``, or whose tag
     equals ``name``, is deep-copied into ``trees``. The method then
     recurses into the modules listed in that entry's ``'inputs'``.

     Args:
         d: Mapping of module tag -> module description. Each description
             must have a ``'kind'`` key and may have ``'inputs'`` (a single
             input or a list of inputs). An input given as a list whose
             last element is ``"conditional"`` names the module in its first
             element; it is validated but not descended into.
         trees: Output mapping, filled in place with tag -> copied entry
             annotated with ``'depth'``, ``'roots'`` and, for entries with
             inputs, nested ``'children'``.
         simple: Output mapping, filled in place with the same tree shape
             as nested ``OrderedDict`` objects.
         kinds: Module kinds whose entries act as roots (default: none).
         name: Tag of a single module to match (used when recursing).
         roots: Root kinds accumulated on the path to this level.
         depth: Current recursion depth.

     Raises:
         RuntimeError: if ``depth`` exceeds 100 (suggesting an input loop)
             or an input names a module tag that is not in ``d``.
     """
     # ``None`` sentinels avoid sharing mutable default lists across calls.
     kinds = [] if kinds is None else kinds
     roots = [] if roots is None else roots
     leaf = kinds if len(kinds) else name
     if depth > 100:
         raise RuntimeError(
             'Error: Tree depth greater than 100, suggests a recursive '
             'input loop in `{}`.'.format(leaf))
     for tag in d:
         is_root_kind = d[tag]['kind'] in kinds
         if not (is_root_kind or tag == name):
             continue
         # Copy only the entries actually placed in the tree.
         entry = deepcopy(d[tag])
         new_roots = list(roots)
         if is_root_kind:
             new_roots.append(entry['kind'])
         entry['depth'] = depth
         entry['roots'] = sorted(set(new_roots))
         trees[tag] = entry
         simple[tag] = OrderedDict()
         for inps in listify(entry.get('inputs', [])):
             conditional = (isinstance(inps, list)
                            and inps[-1] == "conditional")
             inp = inps[0] if conditional else inps
             if inp not in d:
                 suggests = get_close_matches(inp, d, n=1, cutoff=0.8)
                 warn_str = ('Module `{}` for input to `{}` '
                             'not found!'.format(inp, leaf))
                 if len(suggests):
                     warn_str += (' Did you perhaps mean `{}`?'.format(
                         suggests[0]))
                 raise RuntimeError(warn_str)
             # Conditional inputs don't propagate down the tree.
             if conditional:
                 continue
             children = OrderedDict()
             simple_children = OrderedDict()
             self.construct_trees(d,
                                  children,
                                  simple_children,
                                  name=inp,
                                  roots=new_roots,
                                  depth=depth + 1)
             trees[tag].setdefault('children', OrderedDict())
             trees[tag]['children'].update(children)
             simple[tag].update(simple_children)
コード例 #3
0
ファイル: fitter.py プロジェクト: Hoptune/MOSFiT
    def fit_events(self,
                   events=[],
                   models=[],
                   max_time='',
                   time_list=[],
                   time_unit=None,
                   band_list=[],
                   band_systems=[],
                   band_instruments=[],
                   band_bandsets=[],
                   band_sampling_points=17,
                   iterations=10000,
                   num_walkers=None,
                   num_temps=1,
                   parameter_paths=['parameters.json'],
                   fracking=True,
                   frack_step=50,
                   burn=None,
                   post_burn=None,
                   gibbs=False,
                   smooth_times=-1,
                   extrapolate_time=0.0,
                   limit_fitting_mjds=False,
                   exclude_bands=[],
                   exclude_instruments=[],
                   exclude_systems=[],
                   exclude_sources=[],
                   exclude_kinds=[],
                   output_path='',
                   suffix='',
                   upload=False,
                   write=False,
                   upload_token='',
                   check_upload_quality=False,
                   variance_for_each=[],
                   user_fixed_parameters=[],
                   user_released_parameters=[],
                   convergence_type=None,
                   convergence_criteria=None,
                   save_full_chain=False,
                   draw_above_likelihood=False,
                   maximum_walltime=False,
                   start_time=False,
                   print_trees=False,
                   maximum_memory=np.inf,
                   speak=False,
                   return_fits=True,
                   extra_outputs=None,
                   walker_paths=[],
                   catalogs=[],
                   exit_on_prompt=False,
                   download_recommended_data=False,
                   local_data_only=False,
                   method=None,
                   seed=None,
                   **kwargs):
        """Fit a list of events with a list of models.

        Each event is fetched and loaded, or replaced by generated dummy
        data when there is no event. For every model name and every
        parameter path, a ``Model`` is built, the data are loaded into it and
        the model is fitted with ``self.fit_data``. Most keyword arguments
        are forwarded unchanged:

        * to ``Model.load_data``: data selection/exclusion and band options;
        * to ``self.generate_dummy_data``: ``max_time``, ``time_list`` and
          the ``band_*`` lists;
        * to ``self.fit_data``: sampler, convergence, output and upload
          options.

        Other notable arguments:
            walker_paths: JSON files with saved model realizations. Their
                parameters and optional weights are stored in
                ``self._walker_data`` before fitting.
            catalogs: Passed to ``self._fetcher.add_excluded_catalogs``
                (per the inline comment, catalogs not in this list are
                excluded).
            seed: If not ``None``, seeds NumPy's global RNG.
            start_time: Reference start time; ``time.time()`` when
                ``False``.
            return_fits: If false, fit results are not stored in the
                returned lists.

        NOTE(review): ``exit_on_prompt`` and ``**kwargs`` are not used in
        this method. The list defaults are shared between calls; this
        method only reads or forwards them (TODO confirm callees never
        mutate them).

        Returns:
            ``(entries, ps, lnprobs)``: three lists with one item per event.
            For events that reach fitting, that item is a list with one slot
            per model, holding deep copies of what ``fit_data`` returned.
            With several ``parameter_paths``, a slot keeps the result of the
            last path that loaded successfully. Skipped events keep an
            empty list, and skipped models keep ``None``.

        Raises:
            RuntimeError: on the master process, if a path in
                ``walker_paths`` does not exist.
        """
        global model
        if start_time is False:
            start_time = time.time()

        self._seed = seed
        if seed is not None:
            np.random.seed(seed)

        self._start_time = start_time
        self._maximum_walltime = maximum_walltime
        self._maximum_memory = maximum_memory
        self._debug = False
        self._speak = speak
        self._download_recommended_data = download_recommended_data
        self._local_data_only = local_data_only

        self._draw_above_likelihood = draw_above_likelihood

        prt = self._printer

        event_list = listify(events)
        model_list = listify(models)

        # Models without events: use one placeholder event so dummy data
        # can be generated for each model below.
        if len(model_list) and not len(event_list):
            event_list = ['']

        # Exclude catalogs not included in catalog list.
        self._fetcher.add_excluded_catalogs(catalogs)

        if not len(event_list) and not len(model_list):
            prt.message('no_events_models', warning=True)

        # If the input is not a JSON file, assume it is either a list of
        # transients or that it is the data from a single transient in tabular
        # form. Try to guess the format first, and if that fails ask the user.
        self._converter = Converter(prt, require_source=upload)
        event_list = self._converter.generate_event_list(event_list)

        # Normalize non-breaking hyphens (U+2011) in names to ASCII '-'.
        event_list = [x.replace('‑', '-') for x in event_list]

        entries = [[] for _ in event_list]
        ps = [[] for _ in event_list]
        lnprobs = [[] for _ in event_list]

        # Load walker data if provided a list of walker paths.
        walker_data = []

        if len(walker_paths):
            try:
                pool = MPIPool()
            except (ImportError, ValueError):
                pool = SerialPool()
            if pool.is_master():
                prt.message('walker_file')
                for wfi, walker_path in enumerate(walker_paths):
                    if not os.path.exists(walker_path):
                        prt.message('no_walker_data')
                        if self._offline:
                            prt.message('omit_offline')
                        raise RuntimeError

                    prt.prt('  {}'.format(walker_path))
                    with codecs.open(walker_path, 'r',
                                     encoding='utf-8') as f:
                        all_walker_data = json.load(
                            f, object_pairs_hook=OrderedDict)

                    # Support both the format where all data stored in a
                    # single-item dictionary (the OAC format) and the older
                    # MOSFiT format where the data was stored in the
                    # top-level dictionary.
                    if ENTRY.NAME not in all_walker_data:
                        all_walker_data = all_walker_data[list(
                            all_walker_data.keys())[0]]

                    # Named so it does not shadow the ``models`` argument.
                    walker_models = all_walker_data.get(ENTRY.MODELS, [])
                    choice = None
                    if len(walker_models) > 1:
                        model_opts = [
                            '{}-{}-{}'.format(x['code'], x['name'],
                                              x['date'])
                            for x in walker_models
                        ]
                        choice = prt.prompt('select_model_walkers',
                                            kind='select',
                                            message=True,
                                            options=model_opts)
                        choice = model_opts.index(choice)
                    elif len(walker_models) == 1:
                        choice = 0

                    if choice is not None:
                        # Each item: [file index, parameters, weight].
                        # Weights are converted to float once here.
                        for realization in walker_models[choice][
                                MODEL.REALIZATIONS]:
                            weight = realization.get(REALIZATION.WEIGHT)
                            walker_data.append([
                                wfi, realization[REALIZATION.PARAMETERS],
                                None if weight is None else float(weight)
                            ])

                    if not len(walker_data):
                        prt.message('no_walker_data')

                # Share the collected walker data with the other ranks.
                for rank in range(1, pool.size + 1):
                    pool.comm.send(walker_data, dest=rank, tag=3)
            else:
                walker_data = pool.comm.recv(source=0, tag=3)
                pool.wait()

            if pool.is_master():
                pool.close()

        self._event_name = 'Batch'
        self._event_path = ''
        self._event_data = {}

        # The master fetches all events and sends them to the other ranks.
        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            fetched_events = self._fetcher.fetch(
                event_list,
                offline=self._offline,
                prefer_cache=self._prefer_cache)

            for rank in range(1, pool.size + 1):
                pool.comm.send(fetched_events, dest=rank, tag=0)
            pool.close()
        else:
            fetched_events = pool.comm.recv(source=0, tag=0)
            pool.wait()

        for ei, event in enumerate(fetched_events):
            if event is not None:
                self._event_name = event.get('name', 'Batch')
                self._event_path = event.get('path', '')
                if not self._event_path:
                    continue
                self._event_data = self._fetcher.load_data(event)
                if not self._event_data:
                    continue

            lmodel_list = model_list if model_list else ['']

            entries[ei] = [None] * len(lmodel_list)
            ps[ei] = [None] * len(lmodel_list)
            lnprobs[ei] = [None] * len(lmodel_list)

            if (event is not None and
                (not self._event_data or ENTRY.PHOTOMETRY
                 not in self._event_data[list(self._event_data.keys())[0]])):
                prt.message('no_photometry', [self._event_name])
                continue

            for mi, mod_name in enumerate(lmodel_list):
                for parameter_path in parameter_paths:
                    # This pool belongs to the model built below and is used
                    # for loading data, fitting and the final close.
                    try:
                        pool = MPIPool()
                    except (ImportError, ValueError):
                        pool = SerialPool()
                    self._model = Model(model=mod_name,
                                        data=self._event_data,
                                        parameter_path=parameter_path,
                                        output_path=output_path,
                                        wrap_length=self._wrap_length,
                                        test=self._test,
                                        printer=prt,
                                        fitter=self,
                                        pool=pool,
                                        print_trees=print_trees)

                    if not self._model._model_name:
                        prt.message('no_models_avail', [self._event_name],
                                    warning=True)
                        continue

                    if not event:
                        prt.message('gen_dummy')
                        self._event_name = mod_name
                        gen_args = {
                            'name': mod_name,
                            'max_time': max_time,
                            'time_list': time_list,
                            'band_list': band_list,
                            'band_systems': band_systems,
                            'band_instruments': band_instruments,
                            'band_bandsets': band_bandsets
                        }
                        self._event_data = self.generate_dummy_data(**gen_args)

                    success = False
                    alt_name = None
                    while not success:
                        self._model.reset_unset_recommended_keys()
                        success = self._model.load_data(
                            self._event_data,
                            event_name=self._event_name,
                            smooth_times=smooth_times,
                            extrapolate_time=extrapolate_time,
                            limit_fitting_mjds=limit_fitting_mjds,
                            exclude_bands=exclude_bands,
                            exclude_instruments=exclude_instruments,
                            exclude_systems=exclude_systems,
                            exclude_sources=exclude_sources,
                            exclude_kinds=exclude_kinds,
                            time_list=time_list,
                            time_unit=time_unit,
                            band_list=band_list,
                            band_systems=band_systems,
                            band_instruments=band_instruments,
                            band_bandsets=band_bandsets,
                            band_sampling_points=band_sampling_points,
                            variance_for_each=variance_for_each,
                            user_fixed_parameters=user_fixed_parameters,
                            user_released_parameters=user_released_parameters,
                            pool=pool)

                        if not success:
                            break

                        if self._local_data_only:
                            break

                        # If our data is missing recommended keys, offer the
                        # user option to pull the missing data from online and
                        # merge with existing data.
                        urk = self._model.get_unset_recommended_keys()
                        ptxt = prt.text('acquire_recommended',
                                        [', '.join(list(urk))])
                        while event and len(urk) and (
                                alt_name or self._download_recommended_data
                                or prt.prompt(ptxt, [', '.join(urk)],
                                              kind='bool')):
                            # Use a dedicated pool for this fetch. Rebinding
                            # ``pool`` here would swap the model's pool for
                            # one the master closes just below; that closed
                            # pool would then be handed to ``load_data`` and
                            # ``fit_data``, and the model's own pool would
                            # never be closed.
                            try:
                                fetch_pool = MPIPool()
                            except (ImportError, ValueError):
                                fetch_pool = SerialPool()
                            if fetch_pool.is_master():
                                en = (alt_name
                                      if alt_name else self._event_name)
                                extra_event = self._fetcher.fetch(
                                    en,
                                    offline=self._offline,
                                    prefer_cache=self._prefer_cache)[0]
                                extra_data = self._fetcher.load_data(
                                    extra_event)

                                for rank in range(1, fetch_pool.size + 1):
                                    fetch_pool.comm.send(extra_data,
                                                         dest=rank,
                                                         tag=4)
                                fetch_pool.close()
                            else:
                                extra_data = fetch_pool.comm.recv(
                                    source=0, tag=4)
                                fetch_pool.wait()

                            if extra_data is not None:
                                extra_data = extra_data[list(
                                    extra_data.keys())[0]]

                                # Merge the missing keys into our event and
                                # reload the model with the merged data.
                                for key in urk:
                                    new_val = extra_data.get(key)
                                    self._event_data[list(
                                        self._event_data.keys())
                                                     [0]][key] = new_val
                                    if new_val is not None and len(new_val):
                                        prt.message('extra_value', [
                                            key,
                                            str(new_val[0].get(QUANTITY.VALUE))
                                        ])
                                success = False
                                prt.message('reloading_merged')
                                break
                            else:
                                # Not found: ask for an alternative name; an
                                # empty answer gives up on the extra data.
                                text = prt.text('extra_not_found',
                                                [self._event_name])
                                alt_name = prt.prompt(text, kind='string')
                                if not alt_name:
                                    break

                    if success:
                        self._walker_data = walker_data

                        entry, p, lnprob = self.fit_data(
                            event_name=self._event_name,
                            method=method,
                            iterations=iterations,
                            num_walkers=num_walkers,
                            num_temps=num_temps,
                            burn=burn,
                            post_burn=post_burn,
                            fracking=fracking,
                            frack_step=frack_step,
                            gibbs=gibbs,
                            pool=pool,
                            output_path=output_path,
                            suffix=suffix,
                            write=write,
                            upload=upload,
                            upload_token=upload_token,
                            check_upload_quality=check_upload_quality,
                            convergence_type=convergence_type,
                            convergence_criteria=convergence_criteria,
                            save_full_chain=save_full_chain,
                            extra_outputs=extra_outputs)
                        if return_fits:
                            entries[ei][mi] = deepcopy(entry)
                            ps[ei][mi] = deepcopy(p)
                            lnprobs[ei][mi] = deepcopy(lnprob)

                    if pool.is_master():
                        pool.close()

                    # Remove global model variable and garbage collect.
                    try:
                        model
                    except NameError:
                        pass
                    else:
                        del model
                    del self._model
                    gc.collect()

        return (entries, ps, lnprobs)
コード例 #4
0
    def __init__(self, **kwargs):
        """Initialize module.

        Builds the list of known photometric bands from the rules in
        ``filterrules.json``, which is read from this module's directory.
        Then allocates the per-band metadata arrays.

        Keyword arguments:
            bands: Band name, or list of band names, to load (default
                ``''``). A band name of ``''`` matches every filter of every
                rule.

        All keyword arguments are also forwarded to the parent initializer,
        which presumably provides ``self._pool`` -- TODO confirm. The pool
        is used so that only the master process reads the rules file.
        """
        import codecs

        super(Photometry, self).__init__(**kwargs)

        bands = listify(kwargs.get('bands', ''))

        self._dir_path = os.path.dirname(os.path.realpath(__file__))

        # The master reads the rules once and sends them to the other ranks.
        if self._pool.is_master():
            rules_path = os.path.join(self._dir_path, 'filterrules.json')
            # JSON is UTF-8 text; a bare ``open`` would use the platform's
            # default encoding instead.
            with codecs.open(rules_path, 'r', encoding='utf-8') as f:
                filterrules = json.load(f, object_pairs_hook=OrderedDict)
            for rank in range(1, self._pool.size + 1):
                self._pool.comm.send(filterrules, dest=rank, tag=5)
        else:
            filterrules = self._pool.comm.recv(source=0, tag=5)

        # Expand each non-note rule into every combination of its systems,
        # instruments, bandsets, telescopes and modes. A missing list
        # contributes a single ''. The expansion does not depend on the
        # requested band, so it is built once per rule.
        expanded_rules = []
        for rule in filterrules:
            if '@note' in rule:
                continue
            sysinstperms = [
                {
                    'systems': xx,
                    'instruments': yy,
                    'bandsets': zz,
                    'telescopes': tt,
                    'modes': mm
                }
                for xx in rule.get('systems', [''])
                for yy in rule.get('instruments', [''])
                for zz in rule.get('bandsets', [''])
                for tt in rule.get('telescopes', [''])
                for mm in rule.get('modes', [''])
            ]
            expanded_rules.append((rule, sysinstperms))

        # One entry per (requested band, rule, matching filter, combination),
        # in that order.
        band_list = []
        for band in bands:
            for rule, sysinstperms in expanded_rules:
                for bnd in rule.get('filters', []):
                    if band == bnd or band == '':
                        for perm in sysinstperms:
                            new_band = deepcopy(rule['filters'][bnd])
                            new_band.update(deepcopy(perm))
                            new_band['name'] = bnd
                            band_list.append(new_band)

        self._unique_bands = band_list
        # Per-band metadata columns, aligned with ``self._unique_bands``.
        self._band_insts = np.array(
            [x['instruments'] for x in self._unique_bands], dtype=object)
        self._band_bsets = np.array(
            [x['bandsets'] for x in self._unique_bands], dtype=object)
        self._band_systs = np.array(
            [x['systems'] for x in self._unique_bands], dtype=object)
        self._band_teles = np.array(
            [x['telescopes'] for x in self._unique_bands], dtype=object)
        self._band_modes = np.array(
            [x['modes'] for x in self._unique_bands], dtype=object)
        self._band_names = np.array(
            [x['name'] for x in self._unique_bands], dtype=object)
        self._n_bands = len(self._unique_bands)
        # Per-band containers, initialized to empty/default placeholders.
        self._band_wavelengths = [[] for i in range(self._n_bands)]
        self._band_energies = [[] for i in range(self._n_bands)]
        self._transmissions = [[] for i in range(self._n_bands)]
        self._band_areas = [[] for i in range(self._n_bands)]
        self._min_waves = np.full(self._n_bands, 0.0)
        self._max_waves = np.full(self._n_bands, 0.0)
        self._imp_waves = [[0.0, 1.0] for i in range(self._n_bands)]
        self._filter_integrals = np.full(self._n_bands, 0.0)
        self._average_wavelengths = np.full(self._n_bands, 0.0)
        self._band_offsets = np.full(self._n_bands, 0.0)
        self._band_xunits = np.full(self._n_bands, 'Angstrom', dtype=object)
        self._band_yunits = np.full(self._n_bands, '', dtype=object)
        self._band_xu = np.full(self._n_bands, u.Angstrom.cgs.scale)
        self._band_yu = np.full(self._n_bands, 1.0)
        self._band_kinds = np.full(self._n_bands, 'magnitude', dtype=object)
        self._band_index_cache = {}

        for i, band in enumerate(self._unique_bands):
            # Units default to Angstrom (x) and dimensionless (y); the
            # ``_band_xu``/``_band_yu`` arrays hold their CGS scale factors.
            self._band_xunits[i] = band.get('xunit', 'Angstrom')
            self._band_yunits[i] = band.get('yunit', '')
            self._band_xu[i] = u.Unit(self._band_xunits[i]).cgs.scale
            self._band_yu[i] = u.Unit(self._band_yunits[i]).cgs.scale

            # A y-unit of cm2 (presumably an effective-area curve) marks the
            # band as a count-rate band instead of a magnitude band.
            if '{0}'.format(self._band_yunits[i]) == 'cm2':
                self._band_kinds[i] = 'countrate'
コード例 #5
0
    def set_data(self,
                 all_data,
                 req_key_values={},
                 subtract_minimum_keys=[],
                 smooth_times=-1,
                 extrapolate_time=0.0,
                 limit_fitting_mjds=False,
                 exclude_bands=[],
                 exclude_instruments=[],
                 exclude_systems=[],
                 exclude_sources=[],
                 exclude_kinds=[],
                 time_unit=None,
                 time_list=[],
                 band_list=[],
                 band_telescopes=[],
                 band_systems=[],
                 band_instruments=[],
                 band_modes=[],
                 band_bandsets=[]):
        """Set transient data."""
        prt = self._printer

        self._all_data = all_data
        self._data = OrderedDict()
        if not self._all_data:
            return
        name = list(self._all_data.keys())[0]
        self._data['name'] = name
        numeric_keys = set()

        ex_kinds = [
            self._EX_REPS.get(x.lower(), x.lower()) for x in exclude_kinds
        ]

        # Construct some source dictionaries for exclusion rules
        src_dict = OrderedDict()
        sources = self._all_data[name].get('sources', [])
        for src in sources:
            if SOURCE.BIBCODE in src:
                src_dict[src[SOURCE.ALIAS]] = src[SOURCE.BIBCODE]
            if SOURCE.ARXIVID in src:
                src_dict[src[SOURCE.ALIAS]] = src[SOURCE.ARXIVID]
            if SOURCE.NAME in src:
                src_dict[src[SOURCE.ALIAS]] = src[SOURCE.NAME]

        for key in self._keys:
            subkeys = self._keys[key]
            req_subkeys = [
                x for x in subkeys if not isinstance(subkeys, dict)
                or 'required' in listify(subkeys[x])
            ]
            num_subkeys = [
                x for x in subkeys if 'numeric' in listify(subkeys[x])
            ]
            boo_subkeys = [
                x for x in subkeys if 'boolean' in listify(subkeys[x])
            ]
            exc_subkeys = [
                x for x in subkeys if 'exclude' in listify(subkeys[x])
            ]

            if (key not in self._all_data[name]
                    and not self._model.is_parameter_fixed_by_user(key)):
                if subkeys.get('value', None) == 'recommended':
                    self._unset_recommended_keys.add(key)
                continue

            subdata = self._all_data[name].get(key)

            if subdata is None:
                continue

            # Only include data that contains all subkeys
            for entry in subdata:
                if any([x not in entry for x in req_subkeys]):
                    continue
                if any([x in entry for x in exc_subkeys]):
                    continue
                if any([
                        x in entry and
                    ((isinstance(entry[x], list) and any([
                        not is_number(y) or np.isnan(float(y))
                        for y in entry[x]
                    ])) or not isinstance(entry[x], list) and
                     (not is_number(entry[x]) or np.isnan(float(entry[x]))))
                        for x in num_subkeys
                ]):
                    continue

                skip_key = False
                if 'frequency' not in entry:
                    for qkey in req_key_values:
                        if qkey in entry and entry[qkey] != '':
                            if entry[qkey] not in req_key_values[qkey]:
                                skip_key = True
                            break

                if key == 'photometry':
                    if ('fluxdensity' in entry and 'magnitude' not in entry
                            and 'countrate' not in entry):
                        self._kinds_needed.add('radio')
                        if ('radio' in ex_kinds or
                            (not len(ex_kinds) or 'none' not in ex_kinds) and
                                'radio' not in self._model._kinds_supported):
                            continue
                    if (('countrate' in entry or 'unabsorbedflux' in entry
                         or 'flux' in entry) and 'magnitude' not in entry
                            and 'fluxdensity' not in entry):
                        self._kinds_needed.add('x-ray')
                        if ('x-ray' in ex_kinds or
                            (not len(ex_kinds) or 'none' not in ex_kinds) and
                                'x-ray' not in self._model._kinds_supported):
                            continue
                    if 'magnitude' in entry:
                        # For now, magnitudes are not excludable.
                        self._kinds_needed |= set(
                            ['infrared', 'optical', 'ultraviolet'])

                    skip_entry = False

                    for x in subkeys:
                        if limit_fitting_mjds is not False and x == 'time':
                            val = np.mean([
                                float(x) for x in listify(entry.get(x, None))
                            ])
                            if (val < limit_fitting_mjds[0]
                                    or val > limit_fitting_mjds[1]):
                                skip_entry = True
                                break
                        if exclude_bands is not False and x == 'band':
                            if (entry.get(x, '') in exclude_bands and
                                (not exclude_instruments or entry.get(
                                    'instrument', '') in exclude_instruments)
                                    and (not exclude_systems or entry.get(
                                        'system', '') in exclude_systems)):
                                skip_entry = True
                                break
                        if (exclude_instruments is not False
                                and x == 'instrument'):
                            if (entry.get(x, '') in exclude_instruments and
                                (not exclude_bands
                                 or entry.get('band', '') in exclude_bands)
                                    and (not exclude_systems or entry.get(
                                        'system', '') in exclude_systems)):
                                skip_entry = True
                                break
                        if (exclude_systems is not False and x == 'system'):
                            if (entry.get(x, '') in exclude_systems and
                                (not exclude_bands
                                 or entry.get('band', '') in exclude_bands) and
                                (not exclude_instruments or entry.get(
                                    'instrument', '') in exclude_instruments)):
                                skip_entry = True
                                break
                        if (exclude_sources is not False and x == 'source'):
                            val = entry.get(x, '')
                            if (any(
                                [x in exclude_sources for x in val.split(',')])
                                    or any([
                                        src_dict.get(x, '') in exclude_sources
                                        for x in val.split(',')
                                    ])):
                                skip_entry = True
                                break
                    if skip_entry:
                        continue

                    if ((('magnitude' in entry) !=
                         ('band' in entry)) or ((('fluxdensity' in entry) !=
                                                 ('frequency' in entry)) and
                                                ('magnitude' not in entry))
                            or (('countrate' in entry) and
                                ('magnitude' not in entry) and
                                ('instrument' not in entry))):
                        continue

                for x in subkeys:
                    falseval = (False if x in boo_subkeys else
                                None if x in num_subkeys else '')
                    if x == 'value':
                        if not skip_key:
                            self._data[key] = entry.get(x, falseval)
                    else:
                        plural = self._model.plural(x)
                        val = entry.get(x, falseval)
                        if x in num_subkeys:
                            val = None if val is None else np.mean(
                                [float(x) for x in listify(val)])
                        if not skip_key:
                            self._data.setdefault(plural, []).append(val)
                            if x in num_subkeys:
                                numeric_keys.add(plural)
                        else:
                            self._data.setdefault('unmatched_' + plural,
                                                  []).append(val)

        if 'times' not in self._data or not any([
                x in self._data
                for x in ['magnitudes', 'frequencies', 'countrates']
        ]):
            prt.message('no_fittable_data', [name])
            return False

        for key in list(self._data.keys()):
            if isinstance(self._data[key], list):
                self._data[key] = np.array(self._data[key])
                if key not in numeric_keys:
                    continue
                num_values = [
                    x for x in self._data[key] if isinstance(x, float)
                ]
                if len(num_values):
                    self._data['min_' + key] = min(num_values)
                    self._data['max_' + key] = max(num_values)
            else:
                if is_number(self._data[key]):
                    self._data[key] = float(self._data[key])
                    self._data_determined_parameters.append(key)

        if any(x in self._data
               for x in ['magnitudes', 'countrates', 'fluxdensities']):
            # Add a list of tags for each observation to indicate what unit
            # observation is provided in.
            self._data['measures'] = [
                ((['magnitude'] if x else []) + (['countrate'] if y else []) +
                 (['fluxdensity'] if x else []))
                for x, y, z in zip(*(self._data['magnitudes'],
                                     self._data['countrates'],
                                     self._data['fluxdensities']))
            ]

        if 'times' in self._data and (smooth_times >= 0 or time_list):
            # Build an observation array out of the real data first.
            obs = list(
                zip(*(self._data[x] for x in self._OBS_KEYS if x != 'times')))
            # Append extra observations if requested.
            if len(band_list):
                b_teles = band_telescopes if len(band_telescopes) == len(
                    band_list) else (
                        [band_telescopes[0] for x in band_list]
                        if len(band_telescopes) else ['' for x in band_list])
                b_systs = band_systems if len(band_systems) == len(
                    band_list) else ([band_systems[0]
                                      for x in band_list] if len(band_systems)
                                     else ['' for x in band_list])
                b_modes = band_modes if len(band_modes) == len(
                    band_list) else ([band_modes[0]
                                      for x in band_list] if len(band_modes)
                                     else ['' for x in band_list])
                b_insts = band_instruments if len(band_instruments) == len(
                    band_list) else (
                        [band_instruments[0] for x in band_list]
                        if len(band_instruments) else ['' for x in band_list])
                b_bsets = band_bandsets if len(band_bandsets) == len(
                    band_list) else ([band_bandsets[0]
                                      for x in band_list] if len(band_bandsets)
                                     else ['' for x in band_list])
                b_freqs = [None for x in band_list]
                b_u_freqs = ['' for x in band_list]
                b_zerops = [None for x in band_list]
                b_measures = [[] for x in band_list]
                obs.extend(
                    list(
                        zip(*(b_teles, b_systs, b_modes, b_insts, b_bsets,
                              band_list, b_freqs, b_u_freqs, b_zerops,
                              b_measures))))

            # Prune extra observations if they are duplicitous to existing.
            uniqueobs = []
            for o in obs:
                to = tuple(o)
                if to not in uniqueobs:
                    uniqueobs.append(to)

            # Preprend times to real observations list.
            minet, maxet = (extrapolate_time, extrapolate_time) if isinstance(
                extrapolate_time, (float, int)) else (
                    (tuple(extrapolate_time) if len(extrapolate_time) == 2 else
                     (extrapolate_time[0], extrapolate_time[0])))
            mint, maxt = (min(self._data['times']) - minet,
                          max(self._data['times']) + maxet)

            if time_unit is None:
                alltimes = time_list + [x for x in self._data['times']]
            elif time_unit == 'mjd':
                alltimes = [x - min(self._data['times']) for x in time_list
                            ] + [x for x in self._data['times']]
            elif time_unit == 'phase':
                if 'maxdate' not in self._data:
                    raise (prt.message('no_maxdate', name))
                max_mjd = astrotime(self._data['maxdate'].replace('/',
                                                                  '-')).mjd
                alltimes = [x + max_mjd for x in time_list
                            ] + [x for x in self._data['times']]
            else:
                raise ('Unknown `time_unit`.')
            if smooth_times >= 0:
                alltimes += list(np.linspace(mint, maxt, max(smooth_times, 2)))
            alltimes = list(sorted(set(alltimes)))

            # Create additional fake observations.
            currobslist = list(zip(*(self._data[x] for x in self._OBS_KEYS)))
            obslist = []
            for ti, t in enumerate(alltimes):
                new_per = np.round(100.0 * float(ti) / len(alltimes), 1)
                prt.message('construct_obs_array', [new_per], inline=True)
                for o in uniqueobs:
                    newobs = tuple([t] + list(o))
                    if newobs not in currobslist:
                        obslist.append(newobs)
            obslist.sort()

            # Save these fake observations under keys with `extra_` prefix.
            if obslist:
                for x, y in zip(self._OBS_KEYS, zip(*obslist)):
                    self._data['extra_' + x] = y

        for qkey in subtract_minimum_keys:
            if 'upperlimits' in self._data:
                new_vals = np.array(self._data[qkey])[np.array(
                    self._data['upperlimits']) != True]  # noqa E712
                if new_vals.size:
                    self._data['min_' + qkey] = min(new_vals)
                    self._data['max_' + qkey] = max(new_vals)
            minv = self._data['min_' + qkey]
            self._data[qkey] = [x - minv for x in self._data[qkey]]
            if 'extra_' + qkey in self._data:
                self._data['extra_' + qkey] = [
                    x - minv for x in self._data['extra_' + qkey]
                ]

        return True
コード例 #6
0
    def fetch(self, event_list, offline=False):
        """Fetch a list of events from the open catalogs.

        Each entry of ``event_list`` is either a path to a local ``.json``
        file or an event name. Names are resolved against the cached
        ``<catalog>.names.min.json`` alias files (refreshed from each catalog
        unless ``offline``), falling back to a fuzzy-match prompt. The
        matching event file is then downloaded into the ``cache`` directory
        and loaded.

        Parameters
        ----------
        event_list
            Event name(s) and/or path(s), normalized with ``listify``; falsy
            entries are skipped.
        offline : bool
            If True, skip all downloads and use only cached files.

        Returns
        -------
        list
            One entry per input: ``None`` for falsy inputs, otherwise an
            ``OrderedDict`` with ``'name'`` and, when available, ``'catalog'``
            (if resolved through a catalog), ``'path'`` and the loaded JSON
            under ``'data'``.

        Raises
        ------
        RuntimeError
            If a resolved event has no data file on disk.
        """
        dir_path = os.path.dirname(os.path.realpath(__file__))
        prt = self._printer

        levent_list = listify(event_list)
        events = [None for x in levent_list]

        catalogs = OrderedDict([
            (x, self._catalogs[x]) for x in self._catalogs
            if x not in self._excluded_catalogs])

        for ei, event in enumerate(levent_list):
            if not event:
                continue
            events[ei] = OrderedDict()
            path = ''
            # If the event name ends in .json, assume event is a path.
            if event.endswith('.json'):
                path = event
                events[ei]['name'] = event.replace('.json',
                                                   '').split('/')[-1]

            # If not (or the file doesn't exist), download from an open
            # catalog.
            if not path or not os.path.exists(path):
                names_paths = [
                    os.path.join(dir_path, 'cache', x +
                                 '.names.min.json') for x in catalogs]
                input_name = event.replace('.json', '')
                if offline:
                    prt.message('event_interp', [input_name])
                else:
                    # Refresh each catalog's alias file; failures are
                    # reported but non-fatal (a stale cache may still exist).
                    prt.message('dling_aliases', [input_name])
                    for ci, catalog in enumerate(catalogs):
                        try:
                            response = get_url_file_handle(
                                catalogs[catalog]['json'] +
                                '/names.min.json',
                                timeout=10)
                        except Exception:
                            prt.message(
                                'cant_dl_names', [catalog], warning=True)
                        else:
                            with open_atomic(
                                    names_paths[ci], 'wb') as f:
                                shutil.copyfileobj(response, f)

                # Look for an exact match of the input (or 'SN' + input)
                # among each catalog's event names and aliases.
                names = OrderedDict()
                for ci, catalog in enumerate(catalogs):
                    if os.path.exists(names_paths[ci]):
                        with open(names_paths[ci], 'r') as f:
                            names[catalog] = json.load(
                                f, object_pairs_hook=OrderedDict)
                    else:
                        prt.message('cant_read_names', [catalog],
                                    warning=True)
                        if offline:
                            prt.message('omit_offline')
                        continue

                    if input_name in names[catalog]:
                        events[ei]['name'] = input_name
                        events[ei]['catalog'] = catalog
                    else:
                        for name in names[catalog]:
                            if (input_name in names[catalog][name] or
                                    'SN' + input_name in
                                    names[catalog][name]):
                                events[ei]['name'] = name
                                events[ei]['catalog'] = catalog
                                break

                # No exact match: offer close matches, catalog by catalog.
                if not events[ei].get('name', None):
                    for ci, catalog in enumerate(catalogs):
                        # Catalogs whose alias file could not be read were
                        # skipped above and have no entry in `names`.
                        if catalog not in names:
                            continue
                        namekeys = []
                        for name in names[catalog]:
                            namekeys.extend(names[catalog][name])
                        namekeys = list(sorted(set(namekeys)))
                        matches = get_close_matches(
                            event, namekeys, n=5, cutoff=0.8)
                        # If the input starts with a digit, also try it with
                        # the part of each catalog name before its first
                        # digit (plus 'SN19'/'SN20') prepended.
                        if len(matches) < 5 and is_number(event[0]):
                            prt.message('pef_ext_search')
                            snprefixes = set(('SN19', 'SN20'))
                            for name in names[catalog]:
                                ind = re.search(r"\d", name)
                                if ind and ind.start() > 0:
                                    snprefixes.add(name[:ind.start()])
                            snprefixes = list(sorted(snprefixes))
                            for prefix in snprefixes:
                                testname = prefix + event
                                new_matches = get_close_matches(
                                    testname, namekeys, cutoff=0.95,
                                    n=1)
                                if (len(new_matches) and
                                        new_matches[0] not in matches):
                                    matches.append(new_matches[0])
                                if len(matches) == 5:
                                    break
                        if len(matches):
                            if self._test:
                                response = matches[0]
                            else:
                                response = prt.prompt(
                                    'no_exact_match',
                                    kind='select',
                                    options=matches,
                                    none_string=(
                                        'None of the above, ' +
                                        ('skip this event.' if
                                         ci == len(catalogs) - 1
                                         else
                                         'try the next catalog.')))
                            if response:
                                for name in names[catalog]:
                                    if response in names[
                                            catalog][name]:
                                        events[ei]['name'] = name
                                        events[ei]['catalog'] = catalog
                                        break
                                if events[ei].get('name', None):
                                    break

                # Without a catalog there is nothing to download. This also
                # covers a non-existent `.json` path whose stem matched no
                # alias: it has a name but no catalog.
                if 'catalog' not in events[ei]:
                    prt.message('no_event_by_name')
                    if not events[ei].get('name', None):
                        events[ei]['name'] = input_name
                    continue
                urlname = events[ei]['name'] + '.json'
                name_path = os.path.join(dir_path, 'cache', urlname)

                if offline:
                    prt.message('cached_event', [
                        events[ei]['name'], events[ei]['catalog']])
                else:
                    prt.message('dling_event', [
                        events[ei]['name'], events[ei]['catalog']])
                    try:
                        response = get_url_file_handle(
                            catalogs[events[ei]['catalog']][
                                'json'] + '/json/' + urlname,
                            timeout=10)
                    except Exception:
                        prt.message('cant_dl_event', [
                            events[ei]['name']], warning=True)
                    else:
                        with open_atomic(name_path, 'wb') as f:
                            shutil.copyfileobj(response, f)
                path = name_path

            if os.path.exists(path):
                events[ei]['path'] = path
                # Local-path events have no catalog, hence no web page.
                if self._open_in_browser and 'catalog' in events[ei]:
                    webbrowser.open(
                        catalogs[events[ei]['catalog']]['web'] +
                        events[ei]['name'])
                with open(path, 'r') as f:
                    events[ei]['data'] = json.load(
                        f, object_pairs_hook=OrderedDict)
                prt.message('event_file', [path], wrapped=True)
            else:
                prt.message('no_data', [
                    events[ei]['name'],
                    '/'.join(catalogs.keys())])
                if offline:
                    prt.message('omit_offline')
                raise RuntimeError

        return events
コード例 #7
0
 def add_excluded_catalogs(self, catalogs):
     """Add catalog name(s) to the list of catalogs to exclude.

     A genuine (non-string) list is used as-is; anything else is first
     normalized with ``listify``. Names are stored upper-cased in
     ``self._excluded_catalogs``.
     """
     if isinstance(catalogs, list) and not isinstance(catalogs, string_types):
         catalog_names = catalogs
     else:
         catalog_names = listify(catalogs)
     upper_names = [cat_name.upper() for cat_name in catalog_names]
     self._excluded_catalogs.extend(upper_names)
コード例 #8
0
ファイル: converter.py プロジェクト: zodoctor/MOSFiT
    def assign_columns(self, cidict, flines):
        """Assign columns based on header.

        Leading rows of ``flines`` that contain no numeric entries are
        treated as header rows and matched against ``self._header_keys``.
        The index of the first row containing a number is stored in
        ``self._first_data``. Keys that are still unassigned are then
        resolved by prompting the user.

        Parameters
        ----------
        cidict : dict
            Mapping of photometry key -> column assignment, updated in place.
            A value may be a column index, a ``[unit, column_index]`` pair
            (for header keys stored as tuples), a list (several column
            indices, band names, or a joined-column spec starting with
            ``'j'``), or a user-specified string value.
        flines : list of list of str
            The file split into rows of fields.

        Sets ``self._first_data``, ``self._data_type`` (1 = magnitudes,
        2 = counts, 3 = flux densities), ``self._zp``, ``self._ufd`` and
        ``self._system``, and may set ``self._use_mc``.
        """
        used_cis = OrderedDict()
        akeys = list(self._critical_keys) + list(self._helpful_keys)
        dkeys = list(self._dep_keys)
        prt = self._printer

        for fi, fl in enumerate(flines):
            if not any([is_number(x) for x in fl]):
                # Try to associate column names with common header keys.
                conflict_keys = []
                for ci, col in enumerate(fl):
                    for key in self._header_keys:
                        if any([(x[0] if isinstance(x, tuple)
                                 else x) == col.lower()
                                for x in self._header_keys[key]]):
                            if key in cidict or ci in used_cis:
                                # Conflict: drop the assignment below so the
                                # user is asked for this key later.
                                conflict_keys.append(key)
                            else:
                                ind = [
                                    (x[0] if isinstance(x, tuple) else x)
                                    for x in self._header_keys[key]].index(
                                        col.lower())
                                match = self._header_keys[key][ind]
                                cidict[key] = [match[-1], ci] if isinstance(
                                    match, tuple) else ci
                                used_cis[ci] = key
                            break

                for ck in conflict_keys:
                    if ck in cidict:
                        ci = cidict.pop(ck)
                        # Tuple header matches are stored as
                        # `[unit, column_index]`; only the index is a
                        # (hashable) `used_cis` key.
                        if isinstance(ci, list):
                            ci = ci[-1]
                        used_cis.pop(ci, None)
            else:
                self._first_data = fi
                break

        # Look for columns that are band names if no mag/counts/flux dens
        # column was found.
        if (not any([x in cidict for x in [
            PHOTOMETRY.MAGNITUDE, PHOTOMETRY.COUNT_RATE,
                PHOTOMETRY.FLUX_DENSITY]])):
            # Delete `E_MAGNITUDE` and `BAND` if they exist (we'll need to find
            # for each column).
            key = PHOTOMETRY.MAGNITUDE
            ekey = PHOTOMETRY.E_MAGNITUDE
            bkey = PHOTOMETRY.BAND
            for rkey in (ekey, bkey):
                if rkey in cidict:
                    ci = cidict.pop(rkey)
                    # Same `[unit, column_index]` unwrapping as above.
                    if isinstance(ci, list):
                        ci = ci[-1]
                    used_cis.pop(ci, None)
            for fi, fl in enumerate(flines):
                if not any([is_number(x) for x in fl]):
                    # Try to associate column names with common header keys.
                    for ci, col in enumerate(fl):
                        if ci in used_cis:
                            continue
                        if col in self._band_names:
                            cidict.setdefault(key, []).append(ci)
                            used_cis[ci] = key
                            cidict.setdefault(bkey, []).append(col)
                        elif col in self._emagstrs:
                            cidict.setdefault(ekey, []).append(ci)
                            used_cis[ci] = ekey

        # See which keys we collected. If we are missing any critical keys, ask
        # the user which column they are.

        # First ask the user if this data is in magnitudes or in counts.
        self._data_type = 1
        if (PHOTOMETRY.MAGNITUDE in cidict and
                PHOTOMETRY.COUNT_RATE not in cidict and
                PHOTOMETRY.FLUX_DENSITY not in cidict):
            self._data_type = 1
        elif (PHOTOMETRY.MAGNITUDE not in cidict and
              PHOTOMETRY.COUNT_RATE in cidict and
              PHOTOMETRY.FLUX_DENSITY not in cidict):
            self._data_type = 2
        elif (PHOTOMETRY.MAGNITUDE not in cidict and
              PHOTOMETRY.COUNT_RATE not in cidict and
              PHOTOMETRY.FLUX_DENSITY in cidict):
            self._data_type = 3
        else:
            self._data_type = prt.prompt(
                'counts_mags_fds', kind='option',
                options=['Magnitudes', 'Counts (per second)',
                         'Flux Densities (Jansky)'],
                none_string=None)

        # Drop the keys that do not apply to the chosen data type.
        if self._data_type in [1, 3]:
            akeys.remove(PHOTOMETRY.COUNT_RATE)
            akeys.remove(PHOTOMETRY.E_COUNT_RATE)
            akeys.remove(PHOTOMETRY.ZERO_POINT)
            if (PHOTOMETRY.MAGNITUDE in akeys and
                    PHOTOMETRY.E_MAGNITUDE in akeys):
                # Ask for the magnitude error right after the magnitude.
                akeys.remove(PHOTOMETRY.E_MAGNITUDE)
                akeys.insert(
                    akeys.index(PHOTOMETRY.MAGNITUDE) + 1,
                    PHOTOMETRY.E_MAGNITUDE)
            dkeys.remove(PHOTOMETRY.E_COUNT_RATE)
        if self._data_type in [2, 3]:
            akeys.remove(PHOTOMETRY.MAGNITUDE)
            akeys.remove(PHOTOMETRY.E_MAGNITUDE)
            dkeys.remove(PHOTOMETRY.E_MAGNITUDE)
        if self._data_type in [1, 2]:
            akeys.remove(PHOTOMETRY.FLUX_DENSITY)
            akeys.remove(PHOTOMETRY.E_FLUX_DENSITY)
            dkeys.remove(PHOTOMETRY.U_FLUX_DENSITY) if False else None
            dkeys.remove(PHOTOMETRY.E_FLUX_DENSITY)
            dkeys.remove(PHOTOMETRY.U_FLUX_DENSITY)

        # When asymmetric (lower/upper) errors were found, the symmetric error
        # column is not needed. Checked after the type pruning above so that a
        # key which was already dropped is not removed a second time (which
        # raised `ValueError`).
        if (PHOTOMETRY.E_LOWER_MAGNITUDE in cidict and
                PHOTOMETRY.E_UPPER_MAGNITUDE in cidict and
                PHOTOMETRY.E_MAGNITUDE in akeys):
            akeys.remove(PHOTOMETRY.E_MAGNITUDE)
        if (PHOTOMETRY.E_LOWER_FLUX_DENSITY in cidict and
                PHOTOMETRY.E_UPPER_FLUX_DENSITY in cidict and
                PHOTOMETRY.E_FLUX_DENSITY in akeys):
            akeys.remove(PHOTOMETRY.E_FLUX_DENSITY)

        columns = np.array(flines[self._first_data:]).T.tolist()
        colstrs = np.array([
            ', '.join(x[:5]) + ', ...' for x in columns])
        # Indices of columns not yet assigned to any key.
        colinds = np.setdiff1d(np.arange(
            len(colstrs)), list([x[-1] if (
                isinstance(x, list) and not isinstance(
                    x, string_types)) else x for x in cidict.values()]))
        ignore = prt.message('ignore_column', prt=False)
        specify = prt.message('specify_column', prt=False)
        for key in akeys:
            selected_cols = [
                y for y in [a for b in [
                    listify(x) for x in list(cidict.values())] for a in b]
                if isinstance(y, (int, np.integer))]
            if key in cidict:
                continue
            if key in dkeys and self._use_mc:
                continue
            # Offer only columns whose contents fit the key's type.
            if key.type == KEY_TYPES.NUMERIC:
                lcolinds = [x for x in colinds
                            if any(is_number(y) for y in columns[x]) and
                            x not in selected_cols]
            elif key.type == KEY_TYPES.TIME:
                lcolinds = [x for x in colinds
                            if any(is_date(y) or is_number(y)
                                   for y in columns[x]) and
                            x not in selected_cols]
            elif key.type == KEY_TYPES.STRING:
                lcolinds = [x for x in colinds
                            if any(not is_number(y) for y in columns[x]) and
                            x not in selected_cols]
            else:
                lcolinds = [x for x in colinds if x not in selected_cols]
            select = False
            selects = []
            while select is False:
                mc = 1
                if key in self._mc_keys:
                    pkey = self._inflect.plural(key)
                    text = prt.message(
                        'one_per_line', [key, pkey, pkey],
                        prt=False)
                    mc = prt.prompt(
                        text, kind='option', message=False,
                        none_string=None,
                        options=[
                            'One `{}` per row'.format(key),
                            'Multiple `{}` per row'.format(pkey)])
                if mc == 1:
                    text = prt.message(
                        'no_matching_column', [key], prt=False)
                    ns = (
                        ignore if key in (
                            self._optional_keys + self._helpful_keys) else
                        specify if key in self._specify_keys
                        else None)
                    if len(colstrs[lcolinds]):
                        select = prt.prompt(
                            text, message=False,
                            kind='option', none_string=ns,
                            default=('j' if ns is None and
                                     len(colstrs[lcolinds]) > 1
                                     else None if ns is None else 'n'),
                            options=colstrs[lcolinds].tolist() + (
                                [('Multiple columns need to be joined.', 'j')]
                                if len(colstrs[lcolinds]) > 1 else []))
                    else:
                        select = None
                    if select == 'j':
                        select = None
                        jsel = None
                        selects.append('j')
                        while jsel != 'd' and len(lcolinds):
                            jsel = prt.prompt(
                                'join_which_columns', default='d',
                                kind='option', none_string=None,
                                options=colstrs[lcolinds].tolist() + [
                                    ('All columns to be joined '
                                     'have been selected.', 'd')
                                ])
                            if jsel != 'd':
                                selects.append(lcolinds[jsel - 1])
                                lcolinds = np.delete(lcolinds, jsel - 1)
                else:
                    self._use_mc = True
                    select = False
                    while select is not None:
                        text = prt.message(
                            'select_mc_column', [key], prt=False)
                        select = prt.prompt(
                            text, message=False,
                            kind='option', default='n',
                            none_string='No more `{}` columns.'.format(key),
                            options=colstrs[lcolinds].tolist())
                        if select is not None and select is not False:
                            selects.append(lcolinds[select - 1])
                            lcolinds = np.delete(lcolinds, select - 1)
                        else:
                            break
                        for dk in dkeys:
                            dksel = None
                            while dksel is None:
                                text = prt.message(
                                    'select_dep_column', [dk, key], prt=False)
                                sk = dk in self._specify_keys
                                if not sk:
                                    dksel = prt.prompt(
                                        text, message=False,
                                        kind='option', none_string=None,
                                        options=colstrs[lcolinds].tolist())
                                    if dksel is not None:
                                        selects.append(lcolinds[dksel - 1])
                                        lcolinds = np.delete(
                                            lcolinds, dksel - 1)
                                else:
                                    spectext = prt.message(
                                        'specify_mc_value', [dk, key],
                                        prt=False)
                                    val = ''
                                    # Compare by value, not identity: `is ''`
                                    # only worked by string-interning luck.
                                    while not val.strip():
                                        val = prt.prompt(
                                            spectext, message=False,
                                            kind='string')
                                    selects.append(val)
                                    break

            if select is not None:
                iselect = int(select)
                cidict[key] = lcolinds[iselect - 1]
                colinds = np.delete(colinds, np.argwhere(
                    colinds == lcolinds[iselect - 1]))
            elif len(selects):
                if selects[0] == 'j':
                    cidict[key] = selects
                else:
                    # `selects` interleaves the main key with its dependent
                    # keys, one group per selected column.
                    kdkeys = [key] + dkeys
                    allk = list(OrderedDict.fromkeys(kdkeys).keys())
                    for ki, k in enumerate(allk):
                        # NOTE(review): `selects` already holds column indices
                        # taken from `lcolinds`, yet they are re-mapped
                        # through `colinds[s - 1]` here; confirm this offset
                        # is intended.
                        cidict[k] = [
                            colinds[s - 1] if isinstance(s, (
                                int, np.integer)) else s
                            for s in selects[ki::len(allk)]]
                    for s in selects:
                        if not isinstance(s, (int, np.integer)):
                            continue
                        colinds = np.delete(colinds, np.argwhere(
                            colinds == s - 1))
            elif key in self._specify_keys:
                msg = ('specify_value_blank' if key in self._helpful_keys else
                       'specify_value')
                text = prt.message(msg, [key], prt=False)
                cidict[key] = prt.prompt(
                    text, message=False, kind='string', allow_blank=(
                        key in self._helpful_keys))

        # Count-rate data needs a zero point if no column provides one.
        self._zp = ''
        if self._data_type == 2 and PHOTOMETRY.ZERO_POINT not in cidict:
            while not is_number(self._zp):
                self._zp = prt.prompt('zeropoint', kind='string')

        # Flux-density data needs a unit if no column provides one.
        self._ufd = None
        if self._data_type == 3 and PHOTOMETRY.U_FLUX_DENSITY not in cidict:
            while ((self._ufd.lower() if self._ufd is not None else None)
                   not in ['µjy', 'mjy', 'jy', 'microjy', 'millijy',
                           'microjansky', 'millijansky', 'jansky', '']):
                self._ufd = prt.prompt('u_flux_density', kind='string')

        self._system = None
        if self._data_type == 1 and PHOTOMETRY.SYSTEM not in cidict:
            systems = ['AB', 'Vega']
            self._system = prt.prompt(
                'system', kind='option', options=systems,
                none_string='Use default for all bands.',
                default='n')
            if self._system is not None:
                self._system = systems[int(self._system) - 1]

        if (PHOTOMETRY.INSTRUMENT not in cidict and
                PHOTOMETRY.TELESCOPE not in cidict):
            prt.message('instrument_recommended', warning=True)
コード例 #9
0
ファイル: model.py プロジェクト: slowdivePTG/MOSFiT
    def adjust_fixed_parameters(self, variance_for_each=None, output=None):
        """Create free parameters that depend on loaded data.

        When ``variance_for_each`` contains ``'band'``, the ``variance`` task
        is replaced by one copy per photometric band, named
        ``variance-band-<band>``. Bands are visited in order of increasing
        average wavelength (as reported by the ``photometry`` module). A band
        is skipped when its fractional wavelength difference
        ``|w - w_prev| / (w + w_prev)`` to the previously anchored band is
        below a threshold, or when a task of that name already exists. The
        threshold is the number following ``'band'`` in
        ``variance_for_each`` if there is one, else ``self.MIN_WAVE_FRAC_DIFF``.
        The general ``variance`` task is kept as well when any band index in
        ``output['all_band_indices']`` is negative.

        Afterwards, every ``parameter`` task with an input whose value in
        ``output`` is not None is marked fixed (when that input's module is
        free or was fixed by the user), and the task's own name is added to
        its derived keys.

        Parameters
        ----------
        variance_for_each : list or str, optional
            Variance-splitting options; only ``'band'`` (optionally followed
            by a numeric threshold) is acted on here.
        output : dict, optional
            Values keyed by task name; ``'all_band_indices'`` is also read.
        """
        variance_for_each = [] if variance_for_each is None else (
            variance_for_each)
        output = {} if output is None else output

        all_band_indices = output.get('all_band_indices', [])
        unique_band_indices = list(sorted(set(all_band_indices)))
        # NOTE(review): negative indices presumably mark observations not
        # tied to a known band; those still need the band-agnostic variance.
        needs_general_variance = any(np.array(all_band_indices) < 0)
        vfe = listify(variance_for_each)

        new_call_stack = OrderedDict()
        for task in self._call_stack:
            cur_task = self._call_stack[task]
            # Per-band variances need band wavelengths from the photometry
            # module; without it, keep the task unchanged instead of failing
            # on an unbound ``band_pairs``.
            if (task != 'variance' or 'band' not in vfe or
                    'photometry' not in self._call_stack):
                new_call_stack[task] = deepcopy(cur_task)
                continue

            vfi = vfe.index('band') + 1
            mwfd = float(vfe[vfi]) if (vfi < len(vfe) and is_number(
                vfe[vfi])) else self.MIN_WAVE_FRAC_DIFF

            photometry = self._modules['photometry']
            awaves = photometry.average_wavelengths(unique_band_indices)
            abands = photometry.bands(unique_band_indices)
            band_pairs = list(sorted(zip(awaves, abands)))

            owav = 0.0
            variance_bands = []
            for (awav, band) in band_pairs:
                wave_frac_diff = abs(awav - owav) / (awav + owav)
                if wave_frac_diff < mwfd:
                    continue
                new_task_name = '-'.join([task, 'band', band])
                if new_task_name in self._call_stack:
                    continue
                new_task = deepcopy(cur_task)
                if 'latex' in new_task:
                    new_task['latex'] += '_{\\rm ' + band + '}'
                new_call_stack[new_task_name] = new_task
                self._modules[new_task_name] = self._load_task_module(
                    new_task_name, call_stack=new_call_stack)
                # Next band is compared against this anchored wavelength.
                owav = awav
                variance_bands.append([awav, band])

            if needs_general_variance:
                new_call_stack[task] = deepcopy(cur_task)
            if self._pool.is_master():
                self._printer.message(
                    'anchoring_variances',
                    [', '.join([x[1] for x in variance_bands])],
                    wrapped=True)
            photometry.set_variance_bands(variance_bands)
        self._call_stack = new_call_stack

        for task in reversed(self._call_stack):
            cur_task = self._call_stack[task]
            if cur_task.get('kind') != 'parameter':
                continue
            for inp in cur_task.get('inputs', []):
                other = listify(inp)[0]
                if output.get(other, None) is None:
                    continue
                if (not self._modules[other]._fixed
                        or self._modules[other]._fixed_by_user):
                    self._modules[task]._fixed = True
                # Order-preserving de-duplication (``set`` order is arbitrary).
                self._modules[task]._derived_keys = list(OrderedDict.fromkeys(
                    self._modules[task]._derived_keys + [task]))
コード例 #10
0
    def set_data(self,
                 all_data,
                 req_key_values={},
                 subtract_minimum_keys=[],
                 smooth_times=-1,
                 extrapolate_time=0.0,
                 limit_fitting_mjds=False,
                 exclude_bands=[],
                 exclude_instruments=[],
                 exclude_systems=[],
                 exclude_sources=[],
                 exclude_kinds=[],
                 band_list=[],
                 band_telescopes=[],
                 band_systems=[],
                 band_instruments=[],
                 band_modes=[],
                 band_bandsets=[]):
        """Set transient data.

        Flattens the entries of the first transient in ``all_data`` into
        ``self._data``. For every key in ``self._keys``, each accepted entry
        contributes one value per subkey, appended to a list stored under the
        subkey's plural name (``self._model.plural``). A ``'value'`` subkey
        is stored directly under the key instead. Lists are then converted
        to NumPy arrays, and ``min_``/``max_`` bounds are recorded for
        numeric ones.

        Parameters
        ----------
        all_data : dict
            Transient data keyed by name; only the first name is used.
        req_key_values : dict
            Maps an entry field to its allowed values. For entries without a
            ``'frequency'``, the first such field present with a non-blank
            value decides. If its value is not allowed, the entry's values go
            to ``'unmatched_<plural>'`` lists instead.
        subtract_minimum_keys : list
            Data keys whose minimum (over non-upper-limit points when
            ``'upperlimits'`` exist) is subtracted from them and from their
            ``'extra_'`` counterparts.
        smooth_times : int
            If >= 0, ``'extra_*'`` observations are added on a grid of
            ``max(smooth_times, 2)`` times (plus the observed times) for
            every unique observation setup not already observed.
        extrapolate_time : float or sequence
            Padding of that grid before/after the data: a number, or a
            one- or two-element sequence.
        limit_fitting_mjds : False or sequence
            ``(min, max)``; photometry whose mean time falls outside is
            dropped.
        exclude_bands, exclude_instruments, exclude_systems : list
            A photometry entry is dropped when its value is in one of these
            lists and it also matches each of the other two that is
            non-empty.
        exclude_sources : list
            Matched against the entry's comma-separated source aliases and
            the bibcode/arXiv ID/name they map to.
        exclude_kinds : list
            Normalized through ``self._EX_REPS``; ``'radio'`` and ``'x-ray'``
            drop the corresponding photometry entries.
        band_list, band_telescopes, band_systems, band_instruments,
        band_modes, band_bandsets : list
            Extra bands to include in the smoothing grid. Each option list
            either matches ``band_list`` in length, or its first element
            (or ``''`` if empty) is used for every band.

        Returns
        -------
        bool
            ``True`` when data were set. ``False`` when ``all_data`` is empty
            or no fittable data remain, i.e. no times, or none of
            magnitudes, frequencies or count rates.
        """
        prt = self._printer

        self._all_data = all_data
        self._data = OrderedDict()
        if not self._all_data:
            return False
        name = list(self._all_data.keys())[0]
        self._data['name'] = name
        numeric_keys = set()

        ex_kinds = [self._EX_REPS.get(
            x.lower(), x.lower()) for x in exclude_kinds]

        # Map each source alias to its bibcode / arXiv ID / name (the last
        # present of those wins) so exclusion rules can match any of them.
        src_dict = OrderedDict()
        sources = self._all_data[name].get('sources', [])
        for src in sources:
            if SOURCE.BIBCODE in src:
                src_dict[src[SOURCE.ALIAS]] = src[SOURCE.BIBCODE]
            if SOURCE.ARXIVID in src:
                src_dict[src[SOURCE.ALIAS]] = src[SOURCE.ARXIVID]
            if SOURCE.NAME in src:
                src_dict[src[SOURCE.ALIAS]] = src[SOURCE.NAME]

        for key in self._keys:
            subkeys = self._keys[key]
            req_subkeys = [
                x for x in subkeys
                if not isinstance(subkeys, dict) or 'required' in listify(
                    subkeys[x])
            ]
            num_subkeys = [
                x for x in subkeys if 'numeric' in listify(subkeys[x])
            ]
            boo_subkeys = [
                x for x in subkeys if 'boolean' in listify(subkeys[x])
            ]
            exc_subkeys = [
                x for x in subkeys if 'exclude' in listify(subkeys[x])
            ]

            if (key not in self._all_data[name] and not
                    self._model.is_parameter_fixed_by_user(key)):
                if subkeys.get('value', None) == 'recommended':
                    self._unset_recommended_keys.add(key)
                continue

            subdata = self._all_data[name].get(key)

            if subdata is None:
                continue

            # Only include data that contains all subkeys
            for entry in subdata:
                if any([x not in entry for x in req_subkeys]):
                    continue
                if any([x in entry for x in exc_subkeys]):
                    continue
                # Skip entries with a non-numeric or NaN numeric field
                # (scalar or any element of a list).
                if any([
                        x in entry and ((isinstance(entry[x], list) and any([
                            not is_number(y) or np.isnan(float(y))
                            for y in entry[x]
                        ])) or not isinstance(entry[x], list) and (
                            not is_number(entry[x]) or np.isnan(
                                float(entry[x]))))
                        for x in num_subkeys
                ]):
                    continue

                skip_key = False
                if 'frequency' not in entry:
                    for qkey in req_key_values:
                        if qkey in entry and entry[qkey] != '':
                            if entry[qkey] not in req_key_values[qkey]:
                                skip_key = True
                            break

                if key == 'photometry':
                    if 'radio' in ex_kinds:
                        if ('fluxdensity' in entry and
                            'magnitude' not in entry and
                                'countrate' not in entry):
                            continue
                    if 'x-ray' in ex_kinds:
                        if (('countrate' in entry or
                             'unabsorbedflux' in entry or
                             'flux' in entry) and
                            'magnitude' not in entry and
                                'fluxdensity' not in entry):
                            continue

                    skip_entry = False

                    for x in subkeys:
                        if limit_fitting_mjds is not False and x == 'time':
                            val = np.mean([
                                float(t) for t in listify(
                                    entry.get(x, None))])
                            if (val < limit_fitting_mjds[0] or
                                    val > limit_fitting_mjds[1]):
                                skip_entry = True
                                break
                        if exclude_bands is not False and x == 'band':
                            if (entry.get(x, '') in exclude_bands and
                                (not exclude_instruments or entry.get(
                                    'instrument', '') in
                                 exclude_instruments) and (
                                     not exclude_systems or entry.get(
                                    'system', '') in exclude_systems)):
                                skip_entry = True
                                break
                        if (exclude_instruments is not False and
                                x == 'instrument'):
                            if (entry.get(x, '') in exclude_instruments and
                                (not exclude_bands or
                                 entry.get('band', '') in
                                 exclude_bands) and (
                                     not exclude_systems or entry.get(
                                    'system', '') in exclude_systems)):
                                skip_entry = True
                                break
                        if (exclude_systems is not False and
                                x == 'system'):
                            if (entry.get(x, '') in exclude_systems and
                                (not exclude_bands or
                                 entry.get('band', '') in
                                 exclude_bands) and (
                                     not exclude_instruments or entry.get(
                                    'instrument', '') in exclude_instruments)):
                                skip_entry = True
                                break
                        if (exclude_sources is not False and
                                x == 'source'):
                            aliases = entry.get(x, '').split(',')
                            if (any([a in exclude_sources
                                     for a in aliases]) or
                                any([src_dict.get(a, '') in exclude_sources
                                     for a in aliases])):
                                skip_entry = True
                                break
                    if skip_entry:
                        continue

                    # Require paired fields: magnitude<->band,
                    # fluxdensity<->frequency; count rates need a magnitude
                    # or an instrument.
                    if ((('magnitude' in entry) != ('band' in entry)) or
                        (('fluxdensity' in entry) != ('frequency' in entry)) or
                        (('countrate' in entry) and
                         ('magnitude' not in entry) and
                         ('instrument' not in entry))):
                        continue

                for x in subkeys:
                    falseval = (
                        False if x in boo_subkeys else None if
                        x in num_subkeys else '')
                    if x == 'value':
                        if not skip_key:
                            self._data[key] = entry.get(x, falseval)
                    else:
                        plural = self._model.plural(x)
                        val = entry.get(x, falseval)
                        if x in num_subkeys:
                            val = None if val is None else np.mean([
                                float(v) for v in listify(val)])
                        if not skip_key:
                            self._data.setdefault(plural, []).append(val)
                            if x in num_subkeys:
                                numeric_keys.add(plural)
                        else:
                            self._data.setdefault(
                                'unmatched_' + plural, []).append(val)

        if 'times' not in self._data or not any([x in self._data for x in [
                'magnitudes', 'frequencies', 'countrates']]):
            prt.message('no_fittable_data', [name])
            return False

        for key in list(self._data.keys()):
            if isinstance(self._data[key], list):
                self._data[key] = np.array(self._data[key])
                if key not in numeric_keys:
                    continue
                num_values = [
                    x for x in self._data[key] if isinstance(x, float)
                ]
                if len(num_values):
                    self._data['min_' + key] = min(num_values)
                    self._data['max_' + key] = max(num_values)
            else:
                if is_number(self._data[key]):
                    self._data[key] = float(self._data[key])
                    self._data_determined_parameters.append(key)

        if 'times' in self._data and smooth_times >= 0:
            obs = list(
                zip(*(self._data['telescopes'], self._data['systems'],
                      self._data['modes'], self._data['instruments'],
                      self._data['bandsets'], self._data['bands'], self._data[
                          'frequencies'], self._data['u_frequencies'])))
            if len(band_list):
                def broadcast(opts):
                    """Return ``opts`` sized to ``band_list``.

                    Used as-is when the lengths agree; otherwise its first
                    element (or ``''`` if it is empty) is repeated per band.
                    """
                    if len(opts) == len(band_list):
                        return opts
                    fill = opts[0] if len(opts) else ''
                    return [fill for _ in band_list]

                b_freqs = [None for _ in band_list]
                b_u_freqs = ['' for _ in band_list]
                obs.extend(
                    list(
                        zip(*(broadcast(band_telescopes),
                              broadcast(band_systems),
                              broadcast(band_modes),
                              broadcast(band_instruments),
                              broadcast(band_bandsets),
                              band_list, b_freqs, b_u_freqs))))

            # Order-preserving de-duplication in O(n) instead of scanning a
            # growing list for every observation.
            uniqueobs = list(OrderedDict.fromkeys(tuple(o) for o in obs))

            minet, maxet = (extrapolate_time, extrapolate_time) if isinstance(
                extrapolate_time, (float, int)) else (
                    (tuple(extrapolate_time) if len(extrapolate_time) == 2 else
                     (extrapolate_time[0], extrapolate_time[0])))
            mint, maxt = (min(self._data['times']) - minet,
                          max(self._data['times']) + maxet)
            alltimes = list(
                sorted(
                    set(list(self._data['times']) + list(
                        np.linspace(mint, maxt, max(smooth_times, 2))))))
            # A set turns the membership test below into an O(1) lookup;
            # it was a linear scan of all observations per candidate.
            currobs = set(
                zip(*(
                    self._data['times'], self._data['telescopes'],
                    self._data['systems'], self._data['modes'],
                    self._data['instruments'], self._data['bandsets'],
                    self._data['bands'], self._data['frequencies'],
                    self._data['u_frequencies'])))

            obslist = []
            for ti, t in enumerate(alltimes):
                new_per = np.round(100.0 * float(ti) / len(alltimes), 1)
                prt.message('construct_obs_array', [new_per], inline=True)
                for o in uniqueobs:
                    newobs = (t,) + o
                    if newobs not in currobs:
                        obslist.append(newobs)

            obslist.sort()

            if len(obslist):
                (self._data['extra_times'], self._data['extra_telescopes'],
                 self._data['extra_systems'], self._data['extra_modes'],
                 self._data['extra_instruments'], self._data['extra_bandsets'],
                 self._data['extra_bands'], self._data['extra_frequencies'],
                 self._data['extra_u_frequencies']) = zip(*obslist)

        for qkey in subtract_minimum_keys:
            # Upper limits should not set the reference minimum/maximum.
            if 'upperlimits' in self._data:
                new_vals = np.array(self._data[qkey])[
                    np.array(self._data['upperlimits']) != True]  # noqa E712
                if len(new_vals):
                    self._data['min_' + qkey] = min(new_vals)
                    self._data['max_' + qkey] = max(new_vals)
            minv = self._data['min_' + qkey]
            self._data[qkey] = [x - minv for x in self._data[qkey]]
            if 'extra_' + qkey in self._data:
                self._data['extra_' + qkey] = [
                    x - minv for x in self._data['extra_' + qkey]
                ]

        return True