Exemplo n.º 1
0
def fit(dm, method='map', keys=gbd_keys(), iter=50000, burn=25000, thin=1, verbose=1,
        dbname='model.pickle'):
    """ Generate an estimate of the generic disease model parameters
    using maximum a posteriori liklihood (MAP) or Markov-chain Monte
    Carlo (MCMC)

    Parameters
    ----------
    dm : dismod3.DiseaseModel
      the object containing all the data, priors, and additional
      information (like input and output age-mesh)

    method : string, optional
      the parameter estimation method, either 'map' or 'mcmc'

    keys : list, optional
      a list of gbd keys for the parameters to fit; it can speed up
      computation to holding some parameters constant while allowing
      others to vary

    iter : int, optional
    burn : int, optional
    thin : int, optional
      parameters for the MCMC, which control how long it takes, and
      how accurate it is
    """
    if not hasattr(dm, 'vars'):
        print 'initializing model vars... ',
        dm.calc_effective_sample_size(dm.data)
        dm.vars = setup(dm, keys)
        print 'finished'

    if method == 'map':
        print 'initializing MAP object... ',
        map_method = 'fmin_powell'
        #map_method = 'fmin_l_bfgs_b'

        mc.MAP([dm.vars[k] for k in keys if k.find('incidence') != -1]).fit(method=map_method, iterlim=500, tol=.01, verbose=verbose)
        mc.MAP([dm.vars[k] for k in keys if k.find('remission') != -1]).fit(method=map_method, iterlim=500, tol=.01, verbose=verbose)
        mc.MAP([dm.vars[k] for k in keys if
                k.find('excess-mortality') != -1 or
                k.find('m') != -1 or
                k.find('mortality') != -1 or
                k.find('relative-risk') != -1 or
                k.find('bins') != -1]).fit(method=map_method, iterlim=500, tol=.01, verbose=verbose)
        mc.MAP([dm.vars[k] for k in keys if
                k.find('incidence') != -1 or
                k.find('bins') != -1 or
                k.find('prevalence') != -1]).fit(method=map_method, iterlim=500, tol=.01, verbose=verbose)
        mc.MAP([dm.vars[k] for k in keys if
                k.find('excess-mortality') != -1 or
                k.find('m') != -1 or
                k.find('mortality') != -1 or
                k.find('relative-risk') != -1 or
                k.find('bins') != -1 or
                k.find('prevalence') != -1]).fit(method=map_method, iterlim=500, tol=.01, verbose=verbose)

        dm.map = mc.MAP(dm.vars)
        print 'finished'

        try:
            dm.map.fit(method=map_method, iterlim=500, tol=.001, verbose=verbose)
        except KeyboardInterrupt:
            # if user cancels with cntl-c, save current values for "warm-start"
            pass
        
        for k in keys:
            try:
                val = dm.vars[k]['rate_stoch'].value
                dm.set_map(k, val)
            except KeyError:
                pass

    if method == 'norm_approx':
        dm.na = mc.NormApprox(dm.vars, eps=.0001)

        try:
            dm.na.fit(method='fmin_powell', iterlim=500, tol=.00001, verbose=verbose)
        except KeyboardInterrupt:
            # if user cancels with cntl-c, save current values for "warm-start"
            pass

        for k in keys:
            if dm.vars[k].has_key('rate_stoch'):
                dm.set_map(k, dm.vars[k]['rate_stoch'].value)

        try:
            dm.na.sample(1000, verbose=verbose)
            for k in keys:
                # TODO: rename 'rate_stoch' to something more appropriate
                if dm.vars[k].has_key('rate_stoch'):
                    rate_model.store_mcmc_fit(dm, k, dm.vars[k])
        except KeyboardInterrupt:
            # if user cancels with cntl-c, save current values for "warm-start"
            pass

                        
    elif method == 'mcmc':
        # make pymc warnings go to stdout
        import sys
        mc.warnings.warn = sys.stdout.write
        
        dm.mcmc = mc.MCMC(dm.vars, db='pickle', dbname=dbname)
        for k in keys:
            if 'dispersion_step_sd' in dm.vars[k]:
                dm.mcmc.use_step_method(mc.Metropolis, dm.vars[k]['log_dispersion'],
                                        proposal_sd=dm.vars[k]['dispersion_step_sd'])
            if 'age_coeffs_mesh_step_cov' in dm.vars[k]:
                dm.mcmc.use_step_method(mc.AdaptiveMetropolis, dm.vars[k]['age_coeffs_mesh'],
                                        cov=dm.vars[k]['age_coeffs_mesh_step_cov'], verbose=0)

        try:
            dm.mcmc.sample(iter=iter, thin=thin, burn=burn, verbose=verbose)
        except KeyboardInterrupt:
            # if user cancels with cntl-c, save current values for "warm-start"
            pass
        dm.mcmc.db.commit()

        for k in keys:
            t,r,y,s = type_region_year_sex_from_key(k)
            
            if t in ['incidence', 'prevalence', 'remission', 'excess-mortality', 'mortality', 'prevalence_x_excess-mortality']:
                import neg_binom_model
                neg_binom_model.store_mcmc_fit(dm, k, dm.vars[k])
            elif t in ['relative-risk', 'duration', 'incidence_x_duration']:
                import normal_model
                normal_model.store_mcmc_fit(dm, k, dm.vars[k])
Exemplo n.º 2
0
def _log_fit_status(id, opts, status):
    """ Write the status of a fit job to the job status file, if
    opts.log is set

    An empirical prior job (type given, but not all of region, sex,
    and year) is logged under 'empirical_priors'; a posterior job
    (region, sex, and year all given, no type) is logged under
    'posterior'.  Any other combination of options is not logged.
    """
    if not opts.log:
        return

    fully_specified = opts.region and opts.sex and opts.year
    if opts.type and not fully_specified:
        dismod3.log_job_status(id, 'empirical_priors', opts.type, status)
    elif fully_specified and not opts.type:
        dismod3.log_job_status(id, 'posterior', '%s--%s--%s' % (opts.region, opts.sex, opts.year), status)


def fit(id, opts):
    """ Fit the disease model with the given id, and post the results
    to the dismod data server

    Parameters
    ----------
    id : int
      the id of the disease model to fit
    opts : object
      the options for this fit; the attributes read here are region,
      sex, year, type, and log.  If opts.type is set, the empirical
      prior for that type is fit; otherwise all parameters are fit
      consistently (MAP followed by MCMC) for the selected regions,
      sexes, and years, where an unset option selects all of them.

    Notes
    -----
    Posting the results is attempted up to 4 times, with a random
    pause of up to 30 seconds between attempts; the exception from the
    last attempt is not caught.
    """
    # fit_str and url are only used by the announcements, which are
    # currently disabled
    fit_str = '(%d) %s %s %s' % (id, opts.region or '', opts.sex or '', opts.year or '')
    #tweet('fitting disease model %s' % fit_str)
    sys.stdout.flush()

    # update job status file
    _log_fit_status(id, opts, 'Running')

    dm = dismod3.get_disease_model(id)
    fit_str = '%s %s' % (dm.params['condition'], fit_str)

    sex_list = [opts.sex] if opts.sex else dismod3.gbd_sexes
    year_list = [opts.year] if opts.year else dismod3.gbd_years
    region_list = [opts.region] if opts.region else dismod3.gbd_regions
    keys = gbd_keys(region_list=region_list, year_list=year_list, sex_list=sex_list)

    job_dir = dismod3.settings.JOB_WORKING_DIR % id

    # fit empirical priors, if type is specified
    if opts.type:
        fit_str += ' emp prior for %s' % opts.type
        import dismod3.neg_binom_model as model

        model.fit_emp_prior(dm, opts.type, dbname='%s/empirical_priors/pickle/dm-%d-emp_prior-%s.pickle' % (job_dir, id, opts.type))

    # if type is not specified, find consistent fit of all parameters
    else:
        import dismod3.gbd_disease_model as model

        # get the all-cause mortality data, and merge it into the model
        mort = dismod3.get_disease_model('all-cause_mortality')
        dm.data += mort.data

        # fit individually, if sex, year, and region are specified
        if opts.sex and opts.year and opts.region:
            dm.params['estimate_type'] = 'fit individually'

        # fit the model
        model.fit(dm, method='map', keys=keys, verbose=1)
        model.fit(dm, method='mcmc', keys=keys, iter=10000, thin=5, burn=5000, verbose=1,
                  dbname='%s/posterior/pickle/dm-%d-posterior-%s-%s-%s.pickle' % (job_dir, id, opts.region, opts.sex, opts.year))

    # remove all keys that have not been changed by running this model
    for k in dm.params.keys():
        if type(dm.params[k]) == dict:
            # iterate over a snapshot of the keys, since entries are
            # removed inside the loop
            for j in list(dm.params[k].keys()):
                if j not in keys:
                    dm.params[k].pop(j)

    # post results to dismod_data_server
    # "dumb" error handling, in case post fails: sleep a random time
    # and try again, stop after 4 tries
    from twill.errors import TwillAssertionError
    from urllib2 import URLError
    import random

    # this must be a tuple: an except clause matches only a class or a
    # tuple of classes, so a list here would never catch anything
    possible_exceptions = (TwillAssertionError, URLError)
    max_tries = 4
    for attempt in range(max_tries):
        try:
            url = dismod3.post_disease_model(dm)
            break
        except possible_exceptions:
            if attempt == max_tries - 1:
                # out of retries; let the caller see the failure
                raise
            time.sleep(random.random()*30)

    # announce completion, and url to view results
    #tweet('%s fit complete %s' % (fit_str, url))
    sys.stdout.flush()

    # update job status file
    _log_fit_status(id, opts, 'Completed')