Example #1
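This function is MOSFiT's command-line entry point. The excerpt assumes module-level imports along the lines of `codecs`, `locale`, `os`, `shutil`, `sys`, `time`, `numpy as np`, `unicodedata.normalize`, `astropy.time.Time as astrotime` and `six.string_types`, plus MOSFiT's own `Printer`, `Fitter`, `get_parser`, `speak`, `is_master`, `get_mosfit_hash`, `open_atomic`, `__version__`, `__author__` and `__contributors__`.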
def main():
    """Run MOSFiT."""
    prt = Printer(wrap_length=100,
                  quiet=False,
                  language='en',
                  exit_on_prompt=False)

    parser = get_parser(only='language')
    args, remaining = parser.parse_known_args()

    if args.language == 'en':
        loc = locale.getlocale()
        if loc[0]:
            args.language = loc[0].split('_')[0]

    if args.language != 'en':
        try:
            from googletrans.constants import LANGUAGES
        except Exception:
            raise RuntimeError('`--language` requires `googletrans` package, '
                               'install with `pip install googletrans`.')

        if args.language == 'select' or args.language not in LANGUAGES:
            languages = list(
                sorted([
                    LANGUAGES[x].title().replace('_', ' ') + ' (' + x + ')'
                    for x in LANGUAGES
                ]))
            sel = prt.prompt('Select a language:',
                             kind='select',
                             options=languages,
                             message=False)
            args.language = sel.split('(')[-1].strip(')')

    prt = Printer(language=args.language)

    language = args.language

    parser = get_parser(printer=prt)
    args = parser.parse_args()

    args.language = language

    prt = Printer(wrap_length=100,
                  quiet=args.quiet,
                  language=args.language,
                  exit_on_prompt=args.exit_on_prompt)

    if args.version:
        print('MOSFiT v{}'.format(__version__))
        return

    dir_path = os.path.dirname(os.path.realpath(__file__))

    if args.speak:
        speak('Mosfit', args.speak)

    args.start_time = time.time()

    if args.limiting_magnitude == []:
        args.limiting_magnitude = 20.0

    args.return_fits = False

    if (isinstance(args.extrapolate_time, list)
            and len(args.extrapolate_time) == 0):
        args.extrapolate_time = 100.0

    if len(args.band_list) and args.smooth_times == -1:
        prt.message('enabling_s')
        args.smooth_times = 0

    args.method = 'nester' if args.method.lower() in [
        'nest', 'nested', 'nested_sampler', 'nester'
    ] else 'ensembler'

    if is_master():
        if args.method == 'nester':
            unused_args = [[args.burn, '-b'], [args.post_burn, '-p'],
                           [args.frack_step, '-f'], [args.num_temps, '-T'],
                           [args.run_until_uncorrelated, '-U'],
                           [args.draw_above_likelihood, '-d'],
                           [args.gibbs, '-g'], [args.save_full_chain, '-c'],
                           [args.maximum_memory, '-M']]
            for ua in unused_args:
                if ua[0] is not None:
                    prt.message('argument_not_used',
                                reps=[ua[1], '-D nester'],
                                warning=True)

    if args.method == 'nester':
        if args.run_until_converged and args.iterations >= 0:
            raise ValueError(prt.text('R_i_mutually_exclusive'))
        if args.walker_paths is not None:
            raise ValueError(prt.text('w_nester_mutually_exclusive'))

    if args.generative:
        if args.iterations > 0:
            prt.message('generative_supercedes', warning=True)
        args.iterations = 0

    no_events = False
    if args.iterations == -1:
        if len(args.events) == 0:
            no_events = True
            args.iterations = 0
        else:
            args.iterations = 5000

    if len(args.date_list):
        if no_events:
            prt.message('no_dates_gen', warning=True)
        else:
            args.time_list = [
                str(astrotime(x.replace('/', '-')).mjd) for x in args.date_list
            ]
            args.time_unit = 'mjd'

    if len(args.mjd_list):
        if no_events:
            prt.message('no_dates_gen', warning=True)
        else:
            args.time_list = args.mjd_list
            args.time_unit = 'mjd'

    if len(args.jd_list):
        if no_events:
            prt.message('no_dates_gen', warning=True)
        else:
            args.time_list = [
                str(astrotime(float(x), format='jd').mjd) for x in args.jd_list
            ]
            args.time_unit = 'mjd'

    if len(args.phase_list):
        if no_events:
            prt.message('no_dates_gen', warning=True)
        else:
            args.time_list = args.phase_list
            args.time_unit = 'phase'

    if len(args.time_list):
        if any(y in x for y in ['-', '/'] for x in args.time_list):
            try:
                args.time_list = [
                    astrotime(x.replace('/', '-')).mjd for x in args.time_list
                ]
            except ValueError:
                if len(args.time_list) == 1 and isinstance(
                        args.time_list[0], string_types):
                    args.time_list = args.time_list[0].split()
                args.time_list = [float(x) for x in args.time_list]
                args.time_unit = 'phase'
        else:
            if any(['+' in x for x in args.time_list]):
                args.time_unit = 'phase'
            args.time_list = [float(x) for x in args.time_list]

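        # Heuristic: JD = MJD + 2400000.5, so values above 2400000 are taken
        # to be Julian Dates and shifted to MJD, while values above 50000 are
        # assumed to already be MJD.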
        if min(args.time_list) > 2400000:
            prt.message('assuming_jd')
            args.time_list = [x - 2400000.5 for x in args.time_list]
            args.time_unit = 'mjd'
        elif min(args.time_list) > 50000:
            prt.message('assuming_mjd')
            args.time_unit = 'mjd'

    if args.burn is None and args.post_burn is None:
        args.burn = int(np.floor(args.iterations / 2))

    if args.frack_step == 0:
        args.fracking = False

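    # Map the convergence flags onto a (type, criteria) pair: -U selects the
    # autocorrelation test ('acor'), while -R selects the Gelman-Rubin PSRF
    # for the ensemble sampler or remaining evidence ('dlogz') for the nester.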
    if (args.run_until_uncorrelated is not None and args.run_until_converged):
        raise ValueError(
            '`-R` and `-U` options are incompatible, please use one or the '
            'other.')
    if args.run_until_uncorrelated is not None:
        args.convergence_type = 'acor'
        args.convergence_criteria = args.run_until_uncorrelated
    elif args.run_until_converged:
        if args.method == 'ensembler':
            args.convergence_type = 'psrf'
            args.convergence_criteria = (1.1
                                         if args.run_until_converged is True
                                         else args.run_until_converged)
        else:
            args.convergence_type = 'dlogz'

    if args.method == 'nester':
        args.convergence_criteria = (0.02 if args.run_until_converged is True
                                     else args.run_until_converged)

    if is_master():
        # Get hash of ourselves
        mosfit_hash = get_mosfit_hash()

        # Print our amazing ASCII logo.
        if not args.quiet:
            with codecs.open(os.path.join(dir_path, 'logo.txt'), 'r',
                             'utf-8') as f:
                logo = f.read()
                firstline = logo.split('\n')[0]
                # if isinstance(firstline, bytes):
                #     firstline = firstline.decode('utf-8')
                width = len(normalize('NFC', firstline))
            prt.prt(logo, colorify=True)
            prt.message(
                'byline',
                reps=[__version__, mosfit_hash, __author__, __contributors__],
                center=True,
                colorify=True,
                width=width,
                wrapped=False)

        # Get/set upload token
        upload_token = ''
        get_token_from_user = False
        if args.set_upload_token:
            if args.set_upload_token is not True:
                upload_token = args.set_upload_token
            get_token_from_user = True

        upload_token_path = os.path.join(dir_path, 'cache', 'dropbox.token')

        # Perform a few checks on upload before running (to keep size
        # manageable)
        if args.upload and not args.test and args.smooth_times > 100:
            response = prt.prompt('ul_warning_smooth')
            if response:
                args.upload = False
            else:
                sys.exit()

        if (args.upload and not args.test and args.num_walkers is not None
                and args.num_walkers < 100):
            response = prt.prompt('ul_warning_few_walkers')
            if response:
                args.upload = False
            else:
                sys.exit()

        if (args.upload and not args.test and args.num_walkers
                and args.num_walkers * args.num_temps > 500):
            response = prt.prompt('ul_warning_too_many_walkers')
            if response:
                args.upload = False
            else:
                sys.exit()

        if args.upload:
            if not os.path.isfile(upload_token_path):
                get_token_from_user = True
            else:
                with open(upload_token_path, 'r') as f:
                    upload_token = f.read().splitlines()
                    if len(upload_token) != 1:
                        get_token_from_user = True
                    elif len(upload_token[0]) != 64:
                        get_token_from_user = True
                    else:
                        upload_token = upload_token[0]

        if get_token_from_user:
            if args.test:
                upload_token = ('1234567890abcdefghijklmnopqrstuvwxyz'
                                '1234567890abcdefghijklmnopqr')
            while len(upload_token) != 64:
                prt.message('no_ul_token', ['https://sne.space/mosfit/'],
                            wrapped=True)
                upload_token = prt.prompt('paste_token', kind='string')
                if len(upload_token) != 64:
                    prt.prt(
                        'Error: Token must be exactly 64 characters in '
                        'length.',
                        wrapped=True)
                    continue
                break
            with open_atomic(upload_token_path, 'w') as f:
                f.write(upload_token)

        if args.upload:
            prt.prt("Upload flag set, will upload results after completion.",
                    wrapped=True)
            prt.prt("Dropbox token: " + upload_token, wrapped=True)

        args.upload_token = upload_token

        if no_events:
            prt.message('iterations_0', wrapped=True)

        # Create the user directory structure, if it doesn't already exist.
        if args.copy:
            prt.message('copying')
            fc = False
            if args.force_copy:
                fc = prt.prompt('force_copy')
            if not os.path.exists('jupyter'):
                os.mkdir('jupyter')
            if not os.path.isfile(os.path.join('jupyter',
                                               'mosfit.ipynb')) or fc:
                shutil.copy(
                    os.path.join(dir_path, 'jupyter', 'mosfit.ipynb'),
                    os.path.join(os.getcwd(), 'jupyter', 'mosfit.ipynb'))

            if not os.path.exists('modules'):
                os.mkdir('modules')
            module_dirs = next(os.walk(os.path.join(dir_path, 'modules')))[1]
            for mdir in module_dirs:
                if mdir.startswith('__'):
                    continue
                full_mdir = os.path.join(dir_path, 'modules', mdir)
                copy_path = os.path.join(full_mdir, '.copy')
                to_copy = []
                if os.path.isfile(copy_path):
                    with open(copy_path, 'r') as f:
                        to_copy = list(filter(None, f.read().split()))

                mdir_path = os.path.join('modules', mdir)
                if not os.path.exists(mdir_path):
                    os.mkdir(mdir_path)
                for tc in to_copy:
                    tc_path = os.path.join(full_mdir, tc)
                    if os.path.isfile(tc_path):
                        shutil.copy(tc_path, os.path.join(mdir_path, tc))
                    elif os.path.isdir(tc_path) and not os.path.exists(
                            os.path.join(mdir_path, tc)):
                        os.mkdir(os.path.join(mdir_path, tc))
                readme_path = os.path.join(mdir_path, 'README')
                if not os.path.exists(readme_path):
                    txt = prt.message('readme-modules', [
                        os.path.join(dir_path, 'modules', mdir),
                        os.path.join(dir_path, 'modules')
                    ],
                                      prt=False)
                    with open(readme_path, 'w') as f:
                        f.write(txt)

            if not os.path.exists('models'):
                os.mkdir('models')
            model_dirs = next(os.walk(os.path.join(dir_path, 'models')))[1]
            for mdir in model_dirs:
                if mdir.startswith('__'):
                    continue
                mdir_path = os.path.join('models', mdir)
                if not os.path.exists(mdir_path):
                    os.mkdir(mdir_path)
                model_files = next(
                    os.walk(os.path.join(dir_path, 'models', mdir)))[2]
                readme_path = os.path.join(mdir_path, 'README')
                if not os.path.exists(readme_path):
                    txt = prt.message('readme-models', [
                        os.path.join(dir_path, 'models', mdir),
                        os.path.join(dir_path, 'models')
                    ],
                                      prt=False)
                    with open(readme_path, 'w') as f:
                        f.write(txt)
                for mfil in model_files:
                    if 'parameters.json' not in mfil:
                        continue
                    fil_path = os.path.join(mdir_path, mfil)
                    if os.path.isfile(fil_path) and not fc:
                        continue
                    shutil.copy(os.path.join(dir_path, 'models', mdir, mfil),
                                fil_path)

    # Set some default values that we checked above.
    if args.frack_step == 0:
        args.fracking = False
    elif args.frack_step is None:
        args.frack_step = 50
    if args.burn is None and args.post_burn is None:
        args.burn = int(np.floor(args.iterations / 2))
    if args.draw_above_likelihood is None:
        args.draw_above_likelihood = False
    if args.maximum_memory is None:
        args.maximum_memory = np.inf
    if args.gibbs is None:
        args.gibbs = False
    if args.save_full_chain is None:
        args.save_full_chain = False
    if args.num_temps is None:
        args.num_temps = 1
    if args.walker_paths is None:
        args.walker_paths = []

    # Then, fit the listed events with the listed models.
    fitargs = vars(args)
    Fitter(**fitargs).fit_events(**fitargs)
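
A minimal sketch of how this entry point is typically invoked. The `__main__` guard is standard Python; the shell command in the comment is illustrative of MOSFiT's CLI flags (event, model, iterations):

if __name__ == '__main__':
    # Illustrative shell invocation (flags are defined by get_parser()):
    #   mosfit -e SN2006le -m slsn -i 1000
    main()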
Example #2
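This method is an excerpt from MOSFiT's photometry module. It assumes instance state set up elsewhere (`self._unique_bands`, `self._pool`, `self._dir_path` and the various `self._band_*` arrays) plus module-level imports along the lines of `csv`, `os`, `shutil`, `numpy as np`, `collections.OrderedDict`, astropy constants and units bound to `c` and `u`, a VOTable parser bound to `voparse`, and MOSFiT helpers such as `get_url_file_handle`, `open_atomic` and `syst_syns`.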
    def load_bands(self, band_indices):
        """Load band files."""
        prt = self._printer

        if self._pool.is_master():
            vo_tabs = OrderedDict()

        per = 0.0
        bc = 0
        band_set = set(band_indices)
        for i, band in enumerate(self._unique_bands):
            if len(band_indices) and i not in band_set:
                continue
            if self._pool.is_master():
                # Guard against a zero division when no band subset is given.
                total = len(band_set) if band_set else len(self._unique_bands)
                new_per = np.round(100.0 * float(bc) / total)
                if new_per > per:
                    per = new_per
                    prt.message('loading_bands', [per], inline=True)
                systems = ['AB']
                zps = [0.0]
                path = None
                if 'SVO' in band:
                    photsystem = self._band_systs[i]
                    if photsystem in syst_syns:
                        photsystem = syst_syns[photsystem]
                    if photsystem not in systems:
                        systems.append(photsystem)
                    zpfluxes = []
                    for syst in systems:
                        svopath = band['SVO'] + '/' + syst
                        path = os.path.join(self._dir_path, 'filters',
                                            svopath.replace('/', '_') + '.dat')

                        xml_path = os.path.join(
                            self._dir_path, 'filters',
                            svopath.replace('/', '_') + '.xml')
                        if not os.path.exists(xml_path):
                            prt.message('dl_svo', [svopath], inline=True)
                            try:
                                response = get_url_file_handle(
                                    'http://svo2.cab.inta-csic.es'
                                    '/svo/theory/fps3/'
                                    'fps.php?PhotCalID=' + svopath,
                                    timeout=10)
                            except Exception:
                                prt.message('cant_dl_svo', warning=True)
                            else:
                                with open_atomic(xml_path, 'wb') as f:
                                    shutil.copyfileobj(response, f)

                        if os.path.exists(xml_path):
                            already_written = svopath in vo_tabs
                            if not already_written:
                                vo_tabs[svopath] = voparse(xml_path)
                            vo_tab = vo_tabs[svopath]
                            # need to account for zeropoint type

                            for resource in vo_tab.resources:
                                if len(resource.params) == 0:
                                    params = vo_tab.get_first_table().params
                                else:
                                    params = resource.params

                            oldzplen = len(zps)
                            for param in params:
                                if param.name == 'ZeroPoint':
                                    zpfluxes.append(param.value)
                                    if syst != 'AB':
                                        # 0th element is the AB flux.
                                        zps.append(2.5 * np.log10(
                                            zpfluxes[0] / zpfluxes[-1]))
                            if syst != 'AB' and len(zps) == oldzplen:
                                raise RuntimeError(
                                    'ZeroPoint not found in XML.')

                            if not already_written:
                                vo_dat = vo_tab.get_first_table().array
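                                # Trim rows with zero transmission from the
                                # head and tail (keeping one bounding zero on
                                # each side) before caching the curve.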
                                bi = max(
                                    next((i for i, x in enumerate(vo_dat)
                                          if x[1]), 0) - 1, 0)
                                ei = -max(
                                    next((i
                                          for i, x in enumerate(
                                              reversed(vo_dat))
                                          if x[1]), 0) - 1, 0)
                                vo_dat = vo_dat[bi:ei if ei else len(vo_dat)]
                                vo_string = '\n'.join([
                                    ' '.join([str(y) for y in x])
                                    for x in vo_dat
                                ])
                                with open_atomic(path, 'w') as f:
                                    f.write(vo_string)
                        else:
                            raise RuntimeError(
                                prt.text('cant_read_svo'))
                    self._unique_bands[i]['origin'] = band['SVO']
                elif all(x in band for x in [
                        'min_wavelength', 'max_wavelength',
                        'delta_wavelength']):
                    nbins = int(np.round((
                        band['max_wavelength'] -
                        band['min_wavelength']) / band[
                            'delta_wavelength'])) + 1
                    rows = np.array(
                        [np.linspace(
                            band['min_wavelength'], band['max_wavelength'],
                            nbins), np.full(nbins, 1.0)]).T.tolist()
                    self._unique_bands[i]['origin'] = 'generated'
                elif 'path' in band:
                    self._unique_bands[i]['origin'] = band['path']
                    path = band['path']
                else:
                    raise RuntimeError(prt.text('bad_filter_rule'))

                if path:
                    with open(os.path.join(
                            self._dir_path, 'filters', path), 'r') as f:
                        rows = []
                        for row in csv.reader(
                                f, delimiter=' ', skipinitialspace=True):
                            rows.append([float(x) for x in row[:2]])
                for rank in range(1, self._pool.size + 1):
                    self._pool.comm.send(rows, dest=rank, tag=3)
                    self._pool.comm.send(zps, dest=rank, tag=4)
            else:
                rows = self._pool.comm.recv(source=0, tag=3)
                zps = self._pool.comm.recv(source=0, tag=4)

            xvals, yvals = list(
                map(list, zip(*rows)))
            xvals = np.array(xvals)
            yvals = np.array(yvals)

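            # A y-unit of cm2 marks an effective-area response tabulated
            # against energy; wavelengths follow from lambda = h * c / E.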
            if '{0}'.format(self._band_yunits[i]) == 'cm2':
                xscale = (c.h * c.c /
                          u.Angstrom).cgs.value / self._band_xu[i]
                self._band_energies[
                    i], self._band_areas[i] = xvals, yvals / xvals
                self._band_wavelengths[i] = xscale / self._band_energies[i]
                self._average_wavelengths[i] = np.trapz([
                    x * y
                    for x, y in zip(
                        self._band_areas[i], self._band_wavelengths[i])
                ], self._band_wavelengths[i]) / np.trapz(
                    self._band_areas[i], self._band_wavelengths[i])
            else:
                self._band_wavelengths[
                    i], self._transmissions[i] = xvals, yvals
                self._filter_integrals[i] = self.FLUX_STD * np.trapz(
                    np.array(self._transmissions[i]) /
                    np.array(self._band_wavelengths[i]) ** 2,
                    self._band_wavelengths[i])
                self._average_wavelengths[i] = np.trapz([
                    x * y
                    for x, y in zip(
                        self._transmissions[i], self._band_wavelengths[i])
                ], self._band_wavelengths[i]) / np.trapz(
                    self._transmissions[i], self._band_wavelengths[i])

                if 'offset' in band:
                    self._band_offsets[i] = band['offset']
                elif 'SVO' in band:
                    self._band_offsets[i] = zps[-1]

                # Do some sanity checks.
                if (self._band_offsets[i] != 0.0 and
                        self._band_systs[i] == 'AB'):
                    raise RuntimeError(
                        'Filters in AB system should always have offset = '
                        '0.0, not the case for `{}`'.format(
                            self._band_names[i]))

            self._min_waves[i] = min(self._band_wavelengths[i])
            self._max_waves[i] = max(self._band_wavelengths[i])
            self._imp_waves[i] = set([self._min_waves[i], self._max_waves[i]])
            if len(self._transmissions[i]):
                new_wave = self._band_wavelengths[i][
                    np.argmax(self._transmissions[i])]
                self._imp_waves[i].add(new_wave)
            elif len(self._band_areas[i]):
                new_wave = self._band_wavelengths[i][
                    np.argmax(self._band_areas[i])]
                self._imp_waves[i].add(new_wave)
            self._imp_waves[i] = list(sorted(self._imp_waves[i]))
            bc = bc + 1

        if self._pool.is_master():
            prt.message('band_load_complete', inline=True)
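
The magnitude offsets appended to `zps` above come from the ratio of the AB zero-point flux to the target system's zero-point flux. A small worked sketch of that arithmetic, using hypothetical flux values rather than real SVO entries:

import numpy as np

# Hypothetical zero-point fluxes; the first entry plays the role of AB,
# as in the loop above.
zpfluxes = [3631.0, 3499.0]
offset = 2.5 * np.log10(zpfluxes[0] / zpfluxes[-1])
print(round(offset, 4))  # magnitude correction onto the AB scale, ~0.0402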
Example #3
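An excerpt from MOSFiT's Fitter class. It assumes imports along the lines of `time`, `numpy as np`, `copy.deepcopy`, `collections.OrderedDict`, the astrocats `Entry` class with its `ENTRY`/`PHOTOMETRY`/`MODEL`/`REALIZATION`/`SOURCE` key collections, the `Nester` and `Ensembler` samplers, the optional `dropbox` package, and helpers such as `entabbed_json_dump(s)`, `open_atomic`, `slugify`, `all_to_list`, `frequency_unit`, `flux_density_unit`, `get_model_hash` and the module-level `ln_likelihood`/`ln_prior` functions used for scoring.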
    def fit_data(self,
                 event_name='',
                 method=None,
                 iterations=None,
                 frack_step=20,
                 num_walkers=None,
                 num_temps=1,
                 burn=None,
                 post_burn=None,
                 fracking=True,
                 gibbs=False,
                 pool=None,
                 output_path='',
                 suffix='',
                 write=False,
                 upload=False,
                 upload_token='',
                 check_upload_quality=True,
                 convergence_type=None,
                 convergence_criteria=None,
                 save_full_chain=False,
                 extra_outputs=None):
        """Fit the data for a given event.

        Fitting is performed using a combination of emcee and fracking.
        """
        if self._speak:
            speak('Fitting ' + event_name, self._speak)
        from mosfit.__init__ import __version__
        global model
        model = self._model
        prt = self._printer

        upload_model = upload and iterations > 0

        if pool is not None:
            self._pool = pool

        if upload:
            try:
                import dropbox
            except ImportError:
                if self._test:
                    pass
                else:
                    prt.message('install_db', error=True)
                    raise

        if not self._pool.is_master():
            try:
                self._pool.wait()
            except (KeyboardInterrupt, SystemExit):
                pass
            return (None, None, None)

        self._method = method

        if self._method == 'nester':
            self._sampler = Nester(self, model, iterations, burn, post_burn,
                                   num_walkers, convergence_criteria,
                                   convergence_type, gibbs, fracking,
                                   frack_step)
        else:
            self._sampler = Ensembler(self, model, iterations, burn, post_burn,
                                      num_temps, num_walkers,
                                      convergence_criteria, convergence_type,
                                      gibbs, fracking, frack_step)

        self._sampler.run(self._walker_data)

        prt.message('constructing')

        if write:
            if self._speak:
                speak(prt._strings['saving_output'], self._speak)

        if self._event_path:
            entry = Entry.init_from_file(catalog=None,
                                         name=self._event_name,
                                         path=self._event_path,
                                         merge=False,
                                         pop_schema=False,
                                         ignore_keys=[ENTRY.MODELS],
                                         compare_to_existing=False)
            new_photometry = []
            for photo in entry.get(ENTRY.PHOTOMETRY, []):
                if PHOTOMETRY.REALIZATION not in photo:
                    new_photometry.append(photo)
            if len(new_photometry):
                entry[ENTRY.PHOTOMETRY] = new_photometry
        else:
            entry = Entry(name=self._event_name)

        uentry = Entry(name=self._event_name)
        data_keys = set()
        for task in model._call_stack:
            if model._call_stack[task]['kind'] == 'data':
                data_keys.update(
                    list(model._call_stack[task].get('keys', {}).keys()))
        entryhash = entry.get_hash(keys=list(sorted(list(data_keys))))

        # Accumulate all the sources and add them to each entry.
        sources = []
        for root in model._references:
            for ref in model._references[root]:
                sources.append(entry.add_source(**ref))
        sources.append(entry.add_source(**self._DEFAULT_SOURCE))
        source = ','.join(sources)

        usources = []
        for root in model._references:
            for ref in model._references[root]:
                usources.append(uentry.add_source(**ref))
        usources.append(uentry.add_source(**self._DEFAULT_SOURCE))
        usource = ','.join(usources)

        model_setup = OrderedDict()
        for ti, task in enumerate(model._call_stack):
            task_copy = deepcopy(model._call_stack[task])
            if (task_copy['kind'] == 'parameter'
                    and task in model._parameter_json):
                task_copy.update(model._parameter_json[task])
            model_setup[task] = task_copy
        modeldict = OrderedDict([(MODEL.NAME, model._model_name),
                                 (MODEL.SETUP, model_setup),
                                 (MODEL.CODE, 'MOSFiT'),
                                 (MODEL.DATE, time.strftime("%Y/%m/%d")),
                                 (MODEL.VERSION, __version__),
                                 (MODEL.SOURCE, source)])

        self._sampler.prepare_output(check_upload_quality, upload)

        self._sampler.append_output(modeldict)

        umodeldict = deepcopy(modeldict)
        umodeldict[MODEL.SOURCE] = usource
        modelhash = get_model_hash(umodeldict,
                                   ignore_keys=[MODEL.DATE, MODEL.SOURCE])
        umodelnum = uentry.add_model(**umodeldict)

        if self._sampler._upload_model is not None:
            upload_model = self._sampler._upload_model

        modelnum = entry.add_model(**modeldict)

        samples, probs, weights = self._sampler.get_samples()

        extras = OrderedDict()
        samples_to_plot = self._sampler._nwalkers

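        # Nested-sampling output is weighted: draw posterior realizations by
        # inverting the cumulative weight distribution. Ensemble walkers are
        # equally weighted, so every walker is written out directly.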
        if isinstance(self._sampler, Nester):
            icdf = np.cumsum(np.concatenate(([0.0], weights)))
            draws = np.random.rand(samples_to_plot)
            indices = np.searchsorted(icdf, draws) - 1
        else:
            indices = list(range(samples_to_plot))

        ri = 0
        selected_extra = False
        for xi, x in enumerate(samples):
            ri = ri + 1
            prt.message('outputting_walker', [ri, len(samples)],
                        inline=True,
                        min_time=0.2)
            if xi in indices:
                output = model.run_stack(x, root='output')
                if extra_outputs is not None:
                    if not extra_outputs and not selected_extra:
                        extra_options = list(output.keys())
                        prt.message('available_keys')
                        for opt in extra_options:
                            prt.prt('- {}'.format(opt))
                        selected_extra = True
                    for key in extra_outputs:
                        new_val = output.get(key, [])
                        new_val = all_to_list(new_val)
                        extras.setdefault(key, []).append(new_val)
                for i in range(len(output['times'])):
                    if not np.isfinite(output['model_observations'][i]):
                        continue
                    photodict = {
                        PHOTOMETRY.TIME:
                        output['times'][i] + output['min_times'],
                        PHOTOMETRY.MODEL: modelnum,
                        PHOTOMETRY.SOURCE: source,
                        PHOTOMETRY.REALIZATION: str(ri)
                    }
                    if output['observation_types'][i] == 'magnitude':
                        photodict[PHOTOMETRY.BAND] = output['bands'][i]
                        photodict[PHOTOMETRY.
                                  MAGNITUDE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.
                                  E_MAGNITUDE] = output['model_variances'][i]
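                    # 'magcount' observations carry count rates; magnitudes
                    # come from mag = -2.5 * log10(rate) + zeropoint, with
                    # asymmetric errors from the +/- one-sigma rates. When
                    # the variance exceeds the rate itself, the point is
                    # flagged as an upper limit instead.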
                    elif output['observation_types'][i] == 'magcount':
                        if output['model_observations'][i] == 0.0:
                            continue
                        photodict[PHOTOMETRY.BAND] = output['bands'][i]
                        photodict[PHOTOMETRY.
                                  COUNT_RATE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.
                                  E_COUNT_RATE] = output['model_variances'][i]
                        photodict[PHOTOMETRY.MAGNITUDE] = -2.5 * np.log10(
                            output['model_observations']
                            [i]) + output['all_zeropoints'][i]
                        photodict[PHOTOMETRY.E_UPPER_MAGNITUDE] = 2.5 * (
                            np.log10(output['model_observations'][i] +
                                     output['model_variances'][i]) -
                            np.log10(output['model_observations'][i]))
                        if (output['model_variances'][i] >
                                output['model_observations'][i]):
                            photodict[PHOTOMETRY.UPPER_LIMIT] = True
                        else:
                            photodict[PHOTOMETRY.E_LOWER_MAGNITUDE] = 2.5 * (
                                np.log10(output['model_observations'][i]) -
                                np.log10(output['model_observations'][i] -
                                         output['model_variances'][i]))
                    elif output['observation_types'][i] == 'fluxdensity':
                        photodict[PHOTOMETRY.FREQUENCY] = output[
                            'frequencies'][i] * frequency_unit('GHz')
                        photodict[PHOTOMETRY.FLUX_DENSITY] = output[
                            'model_observations'][i] * flux_density_unit('µJy')
                        photodict[PHOTOMETRY.E_LOWER_FLUX_DENSITY] = (
                            photodict[PHOTOMETRY.FLUX_DENSITY] -
                            (10.0**
                             (np.log10(photodict[PHOTOMETRY.FLUX_DENSITY]) -
                              output['model_variances'][i] / 2.5)) *
                            flux_density_unit('µJy'))
                        photodict[PHOTOMETRY.E_UPPER_FLUX_DENSITY] = (
                            10.0**(np.log10(photodict[PHOTOMETRY.FLUX_DENSITY])
                                   + output['model_variances'][i] / 2.5) *
                            flux_density_unit('µJy') -
                            photodict[PHOTOMETRY.FLUX_DENSITY])
                        photodict[PHOTOMETRY.U_FREQUENCY] = 'GHz'
                        photodict[PHOTOMETRY.U_FLUX_DENSITY] = 'µJy'
                    elif output['observation_types'][i] == 'countrate':
                        photodict[PHOTOMETRY.
                                  COUNT_RATE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.E_LOWER_COUNT_RATE] = (
                            photodict[PHOTOMETRY.COUNT_RATE] -
                            (10.0**(np.log10(photodict[PHOTOMETRY.COUNT_RATE])
                                    - output['model_variances'][i] / 2.5)))
                        photodict[PHOTOMETRY.E_UPPER_COUNT_RATE] = (
                            10.0**(np.log10(photodict[PHOTOMETRY.COUNT_RATE]) +
                                   output['model_variances'][i] / 2.5) -
                            photodict[PHOTOMETRY.COUNT_RATE])
                        photodict[PHOTOMETRY.U_COUNT_RATE] = 's^-1'
                    if ('model_upper_limits' in output
                            and output['model_upper_limits'][i]):
                        photodict[PHOTOMETRY.UPPER_LIMIT] = bool(
                            output['model_upper_limits'][i])
                    if self._limiting_magnitude is not None:
                        photodict[PHOTOMETRY.SIMULATED] = True
                    if 'telescopes' in output and output['telescopes'][i]:
                        photodict[
                            PHOTOMETRY.TELESCOPE] = output['telescopes'][i]
                    if 'systems' in output and output['systems'][i]:
                        photodict[PHOTOMETRY.SYSTEM] = output['systems'][i]
                    if 'bandsets' in output and output['bandsets'][i]:
                        photodict[PHOTOMETRY.BAND_SET] = output['bandsets'][i]
                    if 'instruments' in output and output['instruments'][i]:
                        photodict[
                            PHOTOMETRY.INSTRUMENT] = output['instruments'][i]
                    if 'modes' in output and output['modes'][i]:
                        photodict[PHOTOMETRY.MODE] = output['modes'][i]
                    entry.add_photometry(compare_to_existing=False,
                                         check_for_dupes=False,
                                         **photodict)

                    uphotodict = deepcopy(photodict)
                    uphotodict[PHOTOMETRY.SOURCE] = umodelnum
                    uentry.add_photometry(compare_to_existing=False,
                                          check_for_dupes=False,
                                          **uphotodict)
            else:
                output = model.run_stack(x, root='objective')

            parameters = OrderedDict()
            derived_keys = set()
            pi = 0
            for ti, task in enumerate(model._call_stack):
                # if task not in model._free_parameters:
                #     continue
                if model._call_stack[task]['kind'] != 'parameter':
                    continue
                paramdict = OrderedDict(
                    (('latex', model._modules[task].latex()),
                     ('log', model._modules[task].is_log())))
                if task in model._free_parameters:
                    poutput = model._modules[task].process(
                        **{'fraction': x[pi]})
                    value = list(poutput.values())[0]
                    paramdict['value'] = value
                    paramdict['fraction'] = x[pi]
                    pi = pi + 1
                else:
                    if output.get(task, None) is not None:
                        paramdict['value'] = output[task]
                parameters.update({model._modules[task].name(): paramdict})
                # Dump out any derived parameter keys
                derived_keys.update(model._modules[task].get_derived_keys())

            for key in list(sorted(list(derived_keys))):
                if (output.get(key, None) is not None
                        and key not in parameters):
                    parameters.update({key: {'value': output[key]}})

            realdict = {REALIZATION.PARAMETERS: parameters}
            if probs is not None:
                realdict[REALIZATION.SCORE] = str(probs[xi])
            else:
                realdict[REALIZATION.SCORE] = str(
                    ln_likelihood(x) + ln_prior(x))
            realdict[REALIZATION.ALIAS] = str(ri)
            realdict[REALIZATION.WEIGHT] = str(weights[xi])
            entry[ENTRY.MODELS][0].add_realization(check_for_dupes=False,
                                                   **realdict)
            urealdict = deepcopy(realdict)
            uentry[ENTRY.MODELS][0].add_realization(check_for_dupes=False,
                                                    **urealdict)
        prt.message('all_walkers_written', inline=True)

        entry.sanitize()
        oentry = {self._event_name: entry._ordered(entry)}
        uentry.sanitize()
        ouentry = {self._event_name: uentry._ordered(uentry)}

        uname = '_'.join([self._event_name, entryhash, modelhash])

        if output_path and not os.path.exists(output_path):
            os.makedirs(output_path)

        if not os.path.exists(model.get_products_path()):
            os.makedirs(model.get_products_path())

        if write:
            prt.message('writing_complete')
            with open_atomic(
                    os.path.join(model.get_products_path(), 'walkers.json'),
                    'w') as flast, open_atomic(
                        os.path.join(
                            model.get_products_path(), self._event_name +
                            (('_' + suffix) if suffix else '') + '.json'),
                        'w') as feven:
                entabbed_json_dump(oentry, flast, separators=(',', ':'))
                entabbed_json_dump(oentry, feven, separators=(',', ':'))

            if save_full_chain:
                prt.message('writing_full_chain')
                with open_atomic(
                        os.path.join(model.get_products_path(), 'chain.json'),
                        'w') as flast, open_atomic(
                            os.path.join(
                                model.get_products_path(),
                                self._event_name + '_chain' +
                                (('_' + suffix) if suffix else '') + '.json'),
                            'w') as feven:
                    entabbed_json_dump(self._sampler._all_chain.tolist(),
                                       flast,
                                       separators=(',', ':'))
                    entabbed_json_dump(self._sampler._all_chain.tolist(),
                                       feven,
                                       separators=(',', ':'))

            if extra_outputs is not None:
                prt.message('writing_extras')
                with open_atomic(
                        os.path.join(model.get_products_path(), 'extras.json'),
                        'w') as flast, open_atomic(
                            os.path.join(
                                model.get_products_path(),
                                self._event_name + '_extras' +
                                (('_' + suffix) if suffix else '') + '.json'),
                            'w') as feven:
                    entabbed_json_dump(extras, flast, separators=(',', ':'))
                    entabbed_json_dump(extras, feven, separators=(',', ':'))

            prt.message('writing_model')
            with open_atomic(
                    os.path.join(model.get_products_path(), 'upload.json'),
                    'w') as flast, open_atomic(
                        os.path.join(
                            model.get_products_path(), uname +
                            (('_' + suffix) if suffix else '') + '.json'),
                        'w') as feven:
                entabbed_json_dump(ouentry, flast, separators=(',', ':'))
                entabbed_json_dump(ouentry, feven, separators=(',', ':'))

        if upload_model:
            prt.message('ul_fit', [entryhash, self._sampler._modelhash])
            upayload = entabbed_json_dumps(ouentry, separators=(',', ':'))
            try:
                dbx = dropbox.Dropbox(upload_token)
                dbx.files_upload(upayload.encode(),
                                 '/' + uname + '.json',
                                 mode=dropbox.files.WriteMode.overwrite)
                prt.message('ul_complete')
            except Exception:
                if self._test:
                    pass
                else:
                    raise

        if upload:
            for ce in self._converter.get_converted():
                dentry = Entry.init_from_file(catalog=None,
                                              name=ce[0],
                                              path=ce[1],
                                              merge=False,
                                              pop_schema=False,
                                              ignore_keys=[ENTRY.MODELS],
                                              compare_to_existing=False)

                dentry.sanitize()
                odentry = {ce[0]: uentry._ordered(dentry)}
                dpayload = entabbed_json_dumps(odentry, separators=(',', ':'))
                text = prt.message('ul_devent', [ce[0]], prt=False)
                ul_devent = prt.prompt(text, kind='bool', message=False)
                if ul_devent:
                    dpath = '/' + slugify(
                        ce[0] + '_' + dentry[ENTRY.SOURCES][0].get(
                            SOURCE.BIBCODE, dentry[ENTRY.SOURCES][0].get(
                                SOURCE.NAME, 'NOSOURCE'))) + '.json'
                    try:
                        dbx = dropbox.Dropbox(upload_token)
                        dbx.files_upload(
                            dpayload.encode(),
                            dpath,
                            mode=dropbox.files.WriteMode.overwrite)
                        prt.message('ul_complete')
                    except Exception:
                        if self._test:
                            pass
                        else:
                            raise

        return (entry, samples, probs)
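
A rough sketch of a direct call, assuming a configured `Fitter` instance (in practice `fit_data` is reached through `fit_events`, as in Example #1). The keyword names mirror the signature above; the values are illustrative:

entry, samples, probs = fitter.fit_data(
    event_name='SN2006le',  # illustrative event name
    method='ensembler',
    iterations=1000,
    num_walkers=100,
    burn=500,
    write=True,
    output_path='products')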
Example #4
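An excerpt from MOSFiT's fetcher. It assumes imports along the lines of `json`, `os`, `re`, `shutil`, `webbrowser`, `collections.OrderedDict` and `difflib.get_close_matches`, MOSFiT helpers such as `listify`, `is_number`, `get_url_file_handle` and `open_atomic`, plus instance state (`self._catalogs`, `self._excluded_catalogs`, `self._printer`, `self._test`, `self._open_in_browser`).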
    def fetch(self, event_list, offline=False):
        """Fetch a list of events from the open catalogs."""
        dir_path = os.path.dirname(os.path.realpath(__file__))
        prt = self._printer

        levent_list = listify(event_list)
        events = [None for x in levent_list]

        catalogs = OrderedDict([
            (x, self._catalogs[x]) for x in self._catalogs
            if x not in self._excluded_catalogs])

        for ei, event in enumerate(levent_list):
            if not event:
                continue
            events[ei] = OrderedDict()
            path = ''
            # If the event name ends in .json, assume event is a path.
            if event.endswith('.json'):
                path = event
                events[ei]['name'] = event.replace('.json',
                                                   '').split('/')[-1]

            # If not (or the file doesn't exist), download from an open
            # catalog.
            if not path or not os.path.exists(path):
                names_paths = [
                    os.path.join(dir_path, 'cache', x +
                                 '.names.min.json') for x in catalogs]
                input_name = event.replace('.json', '')
                if offline:
                    prt.message('event_interp', [input_name])
                else:
                    prt.message('dling_aliases', [input_name])
                    for ci, catalog in enumerate(catalogs):
                        try:
                            response = get_url_file_handle(
                                catalogs[catalog]['json'] +
                                '/names.min.json',
                                timeout=10)
                        except Exception:
                            prt.message(
                                'cant_dl_names', [catalog], warning=True)
                        else:
                            with open_atomic(
                                    names_paths[ci], 'wb') as f:
                                shutil.copyfileobj(response, f)
                names = OrderedDict()
                for ci, catalog in enumerate(catalogs):
                    if os.path.exists(names_paths[ci]):
                        with open(names_paths[ci], 'r') as f:
                            names[catalog] = json.load(
                                f, object_pairs_hook=OrderedDict)
                    else:
                        prt.message('cant_read_names', [catalog],
                                    warning=True)
                        if offline:
                            prt.message('omit_offline')
                        continue

                    if input_name in names[catalog]:
                        events[ei]['name'] = input_name
                        events[ei]['catalog'] = catalog
                    else:
                        for name in names[catalog]:
                            if (input_name in names[catalog][name] or
                                    'SN' + input_name in
                                    names[catalog][name]):
                                events[ei]['name'] = name
                                events[ei]['catalog'] = catalog
                                break

                if not events[ei].get('name', None):
                    for ci, catalog in enumerate(catalogs):
                        if catalog not in names:
                            continue
                        namekeys = []
                        for name in names[catalog]:
                            namekeys.extend(names[catalog][name])
                        namekeys = list(sorted(set(namekeys)))
                        matches = get_close_matches(
                            event, namekeys, n=5, cutoff=0.8)
                        # matches = []
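                        # If the query looks like a bare designation (starts
                        # with a digit), retry the fuzzy match with survey
                        # prefixes harvested from the catalog's own names.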
                        if len(matches) < 5 and is_number(event[0]):
                            prt.message('pef_ext_search')
                            snprefixes = set(('SN19', 'SN20'))
                            for name in names[catalog]:
                                ind = re.search(r"\d", name)
                                if ind and ind.start() > 0:
                                    snprefixes.add(name[:ind.start()])
                            snprefixes = list(sorted(snprefixes))
                            for prefix in snprefixes:
                                testname = prefix + event
                                new_matches = get_close_matches(
                                    testname, namekeys, cutoff=0.95,
                                    n=1)
                                if (len(new_matches) and
                                        new_matches[0] not in matches):
                                    matches.append(new_matches[0])
                                if len(matches) == 5:
                                    break
                        if len(matches):
                            if self._test:
                                response = matches[0]
                            else:
                                response = prt.prompt(
                                    'no_exact_match',
                                    kind='select',
                                    options=matches,
                                    none_string=(
                                        'None of the above, ' +
                                        ('skip this event.' if
                                         ci == len(catalogs) - 1
                                         else
                                         'try the next catalog.')))
                            if response:
                                for name in names[catalog]:
                                    if response in names[
                                            catalog][name]:
                                        events[ei]['name'] = name
                                        events[ei]['catalog'] = catalog
                                        break
                                if events[ei].get('name', None):
                                    break

                if not events[ei].get('name', None):
                    prt.message('no_event_by_name')
                    events[ei]['name'] = input_name
                    continue
                urlname = events[ei]['name'] + '.json'
                name_path = os.path.join(dir_path, 'cache', urlname)

                if offline:
                    prt.message('cached_event', [
                        events[ei]['name'], events[ei]['catalog']])
                else:
                    prt.message('dling_event', [
                        events[ei]['name'], events[ei]['catalog']])
                    try:
                        response = get_url_file_handle(
                            catalogs[events[ei]['catalog']][
                                'json'] + '/json/' + urlname,
                            timeout=10)
                    except Exception:
                        prt.message('cant_dl_event', [
                            events[ei]['name']], warning=True)
                    else:
                        with open_atomic(name_path, 'wb') as f:
                            shutil.copyfileobj(response, f)
                path = name_path

            if os.path.exists(path):
                events[ei]['path'] = path
                if self._open_in_browser and 'catalog' in events[ei]:
                    webbrowser.open(
                        catalogs[events[ei]['catalog']]['web'] +
                        events[ei]['name'])
                with open(path, 'r') as f:
                    events[ei]['data'] = json.load(
                        f, object_pairs_hook=OrderedDict)
                prt.message('event_file', [path], wrapped=True)
            else:
                prt.message('no_data', [
                    events[ei]['name'],
                    '/'.join(catalogs.keys())])
                if offline:
                    prt.message('omit_offline')
                raise RuntimeError

        return events
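
A short sketch of how `fetch` might be called, assuming a configured fetcher instance; the event names are illustrative. Each returned dict carries the keys populated above ('name', 'catalog', 'path' and the parsed JSON under 'data'):

events = fetcher.fetch(['SN2006le', 'PTF09ge'], offline=False)
for ev in filter(None, events):
    print(ev['name'], ev.get('catalog', 'local'), ev.get('path'))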