コード例 #1
0
ファイル: parsegenosets.py プロジェクト: tritiumoxide/snappy
def parse():
    """Yield ``(name, info)`` for every page tagged 'Is a genoset'.

    ``name`` is ``utils.normalize_name`` applied to the part of the page title
    before the first ``/``.

    * ``.../criteria`` pages yield ``{'c': <page text with all spaces removed>}``.
    * Every other page yields the ``genoset`` template parameters renamed to
      ``r`` (repute), ``m`` (magnitude) and ``s`` (summary), as returned by
      ``utils.extract_parameters(..., delete=True)``; a present ``m`` is passed
      through ``utils.filter_value(float, ...)``.
    """
    for page in utils.iter_dump('Is a genoset'):
        title = page.title.cdata
        text = page.revision.text.cdata
        name = utils.normalize_name(title.partition('/')[0])

        if '/criteria' not in title:
            wikicode = mwparserfromhell.parse(text)
            info = utils.extract_parameters(
                wikicode,
                'genoset',
                {'repute': 'r', 'magnitude': 'm', 'summary': 's'},
                delete=True,
            )
            if 'm' in info:
                info['m'] = utils.filter_value(float, info['m'])
        else:
            # Criteria pages are stored verbatim, minus every space character.
            info = {'c': text.replace(' ', '')}

        yield (name, info)
コード例 #2
0
    def get(self):
        """Proxy a SoundCloud API request and write the reply back to the client.

        The API path/query is taken from ``utils.extract_parameters`` applied
        to the incoming request URI and appended to the SoundCloud base URL.
        The reply body is written through ``utils.print_with_callback`` together
        with the request's ``callback`` parameter.

        * A memcache hit (keyed on ``str(hash(api_parameters))``) is written
          without contacting SoundCloud.
        * Only HTTP 200 replies are cached, for 10800 seconds (three hours).
        * If the upstream fetch raises, ``utils.status_code_json(408)`` is
          written instead.
        * Nothing is written when the URI carries no API parameters.
        """
        sc_api_url = "http://api.soundcloud.com"
        cache_seconds = 10800
        callback = self.request.get('callback')
        api_parameters = utils.extract_parameters(self.request.uri)
        if not api_parameters:
            return

        self.response.headers[
            "Content-Type"] = "text/javascript; charset=utf-8"
        # testing force client caching, works in ff3 at least
        self.response.headers[
            "Cache-Control"] = "max-age=10800, must-revalidate"

        # NOTE(review): hash() is process/platform dependent and can collide,
        # so distinct parameter strings could share a cache entry; consider a
        # hashlib digest of the parameters as the key.
        parameters_hash = str(hash(api_parameters))
        hit = memcache.get(parameters_hash)
        if hit is not None:
            utils.print_with_callback(callback, hit, self.response)
            return

        try:
            # NOTE(review): urllib.quote_plus also percent-encodes '/', '?',
            # '&' and '='. If api_parameters is a path+query such as
            # '/tracks?q=x', the resulting URL is malformed. Confirm what
            # utils.extract_parameters returns.
            response = urlfetch.fetch(
                url=sc_api_url + urllib.quote_plus(api_parameters),
                method=urlfetch.GET,
                headers={
                    'Content-Type': 'text/javascript; charset=utf-8'
                })
        except Exception:
            # Only the upstream fetch is guarded. A failure to reach SoundCloud
            # is reported as a 408 instead of surfacing a server error, while
            # bugs in the code below are no longer masked as timeouts.
            utils.print_with_callback(callback,
                                      utils.status_code_json(408),
                                      self.response)
            return

        if response.status_code == 200:
            # Don't pin transient upstream errors in the cache for three hours.
            memcache.set(parameters_hash, response.content, cache_seconds)
        utils.print_with_callback(callback, response.content, self.response)
コード例 #3
0
ファイル: parsegenotypes.py プロジェクト: tritiumoxide/snappy
def parse_genotypes():
    """Yield ``(snp, {genotype: info})`` for every page tagged 'Is a genotype'.

    Normalised page titles must match ``rs<digits>(<genotype>)`` or
    ``i<digits>(<genotype>)``, e.g. ``rs53576(A;G)``. Titles that do not match
    are reported and skipped.

    * Redirect pages yield ``(snp, {genotype: target_genotype})``, where
      ``target_genotype`` is the parenthesised part of the first wikilink's
      (normalised) title. Redirects without a link or with an unparsable
      target are reported and skipped.
    * Other pages yield the ``genotype`` template parameters renamed to ``r``
      (repute), ``m`` (magnitude) and ``s`` (summary).
      - A genotype written with ``D`` or ``:`` is rewritten with ``-`` / ``;``,
        and the original spelling is kept under ``'o'``.
      - A magnitude equal to ``'0'`` is dropped; any other magnitude is passed
        through ``utils.filter_value(float, ...)``.
    """
    genotype_regex = re.compile(r'((?:rs|i)[0-9]+)\(([^\)]+)\)')
    # Template parameter name -> short output key (built once, not per page).
    param_map = {
        'repute': 'r',
        'magnitude': 'm',
        'summary': 's',
    }

    for page in utils.iter_dump('Is a genotype'):
        ptitle = page.title.cdata
        ptext = page.revision.text.cdata
        name = utils.normalize_name(ptitle)

        # Parse "<snp>(<genotype>)" out of the title.
        matches = genotype_regex.match(name)
        if not matches:
            print('Genotype {} invalid'.format(name))
            continue

        snp = matches.group(1)
        genotype = matches.group(2)
        parsed = mwparserfromhell.parse(ptext)

        # MediaWiki recognises the #REDIRECT magic word case-insensitively.
        if ptext.upper().startswith('#REDIRECT'):
            links = parsed.filter_wikilinks()
            if not links:
                # Indexing [0] here would raise IndexError and abort the dump.
                print('Redirect {} has no target'.format(name))
                continue
            target = utils.normalize_name(links[0].title)
            targetgt = genotype_regex.match(target)
            if not targetgt:
                print('Target genotype {} invalid'.format(target))
                continue

            yield (snp, {genotype: targetgt.group(2)})
            continue

        genotypeinfo = utils.extract_parameters(parsed, 'genotype', param_map,
                                                delete=True)

        if 'D' in genotype or ':' in genotype:
            genotypeinfo['o'] = genotype
            genotype = genotype.replace('D', '-').replace(':', ';')

        if ';' not in genotype:
            # NOTE(review): despite the message, the entry is still yielded
            # below. Add `continue` if single-allele genotypes should really be
            # dropped.
            print('Ignoring {}'.format(name))

        if 'm' in genotypeinfo:
            if genotypeinfo['m'] == '0':
                del genotypeinfo['m']
            else:
                genotypeinfo['m'] = utils.filter_value(float, genotypeinfo['m'])

        yield (snp, {genotype: genotypeinfo})
コード例 #4
0
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        """Set up ``num_particles`` models on a generated toy dataset.

        Args:
            algo: algorithm label; only echoed in the startup message.
            dataset: ``'regression'`` or ``'classification'``. Selects the
                ``toy`` data generator and the loss (MSE vs. cross-entropy).
            kernel_fn: kernel name; only ``'rbf'`` (``rbf_fn``) is supported.
            base_model_fn: zero-argument factory; each result has ``.cuda()``
                called on it (presumably a ``torch.nn.Module``).
            num_particles: number of model copies to build.
            resume, resume_epoch, resume_lr: accepted for signature
                compatibility with sibling constructors; not used here.

        Raises:
            NotImplementedError: for an unsupported ``dataset`` or ``kernel_fn``.
        """
        (self.algo, self.dataset,
         self.kernel_fn, self.num_particles) = (algo, dataset, kernel_fn,
                                                num_particles)
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'regression':
            self.data = toy.generate_regression_data(80, 200)
            train_split, test_split = self.data
            self.train_data, self.train_targets = train_split
            self.test_data, self.test_targets = test_split
        elif self.dataset == 'classification':
            self.train_data, self.train_targets = (
                toy.generate_classification_data(100))
            self.test_data, self.test_targets = (
                toy.generate_classification_data(200))
        else:
            raise NotImplementedError

        if kernel_fn != 'rbf':
            raise NotImplementedError
        self.kernel = rbf_fn

        self.models = [base_model_fn().cuda() for _ in range(num_particles)]
        flat_params, self.state_dict = extract_parameters(self.models)

        # Trainable copy of the extracted parameters; the optimizer below
        # updates only this tensor.
        self.param_set = torch.nn.Parameter(flat_params.clone(),
                                            requires_grad=True)
        self.optimizer = torch.optim.Adam([dict(params=self.param_set,
                                                lr=1e-3)])

        # The dataset was validated above, so this lookup always succeeds.
        loss_classes = {
            'regression': torch.nn.MSELoss,
            'classification': torch.nn.CrossEntropyLoss,
        }
        self.loss_fn = loss_classes[self.dataset]()
        self.kernel_width_averager = Averager(shape=())
コード例 #5
0
 def get(self):
   """Proxy a SoundCloud API request and write the reply back to the client.

   The API path/query comes from ``utils.extract_parameters`` applied to the
   request URI and is appended to the SoundCloud base URL. The reply body is
   written through ``utils.print_with_callback`` with the request's
   ``callback`` parameter.

   * A memcache hit (keyed on ``str(hash(api_parameters))``) is written
     without contacting SoundCloud.
   * Only HTTP 200 replies are cached, for 3600 seconds (one hour).
   * If the upstream fetch raises, ``utils.status_code_json(408)`` is written
     instead.
   * Nothing is written when the URI carries no API parameters.
   """
   sc_api_url = "http://api.soundcloud.com/"
   cache_seconds = 3600
   callback = self.request.get('callback')
   api_parameters = utils.extract_parameters(self.request.uri)
   if not api_parameters:
     return

   self.response.headers["Content-Type"] = "text/javascript; charset=utf-8"
   # testing force client caching, works in ff3 at least
   self.response.headers["Cache-Control"] = "max-age=3600, must-revalidate"

   # NOTE(review): hash() is process/platform dependent and can collide;
   # consider a hashlib digest of the parameters as the cache key.
   parameters_hash = str(hash(api_parameters))
   hit = memcache.get(parameters_hash)
   if hit is not None:
     utils.print_with_callback(callback, hit, self.response)
     return

   try:
     response = urlfetch.fetch(url=sc_api_url + api_parameters,
                               method=urlfetch.GET,
                               headers={'Content-Type': 'text/javascript; charset=utf-8'})
   except Exception:
     # Only the upstream fetch is guarded, so a coding error further down can
     # no longer masquerade as a 408 timeout.
     utils.print_with_callback(callback, utils.status_code_json(408), self.response)
     return

   if response.status_code == 200:
     # Don't pin transient upstream errors in the cache for an hour.
     memcache.set(parameters_hash, response.content, cache_seconds)
   utils.print_with_callback(callback, response.content, self.response)
コード例 #6
0
def parse_snps():
    """Yield ``(name, snpinfo)`` for every page tagged 'Is a snp'.

    ``name`` is the page title with its first character lower-cased.
    ``snpinfo`` holds the parameters of the page's ``rsnum`` /
    ``23andMe SNP`` template, renamed via ``param_map``, with:

    * ``'position'`` passed through ``utils.filter_value(int, ...)``;
    * ``'genes'`` turned into a list of stripped gene names. It is built from
      the comma-separated ``gene_s`` parameter (empty items dropped) or,
      failing that, from ``gene``; the raw ``gene`` key is removed;
    * ``'genotypes'`` taken from the ``geno1``..``geno8`` values of the first
      matching top-level template. Each value is a two-character string with
      ``D`` and empty alleles written as ``-`` (e.g. ``'(A;D)'`` -> ``'A-'``).
      That template is then removed from the parsed wikicode.
    """
    # TODO: {{ClinVar}}
    snp_templates = ['rsnum', '23andMe SNP']
    param_map = {
        'stabilizedOrientation': 'orientation',
        'chromosome': 'chromosome',
        'position': 'position',
        'referenceAllele': 'referenceAllele',
        'missenseAllele': 'missenseAllele',
        'assembly': 'assembly',
        'genomeBuild': 'genomeBuild',
        'dbSNPBuild': 'dbSNPBuild',
        'summary': 'summary',
        'gene_s': 'genes',
        'gene': 'gene',
    }
    # "(A;G)", "(D;I)", "(;A)", ...  ':' separators are normalised to ';'
    # before matching. Raw string: '\(' is not a valid Python escape.
    genotype_regex = re.compile(r'\(([AGCTDIN-]*) *;([AGCTDIN-]*) *\)')

    for page in utils.iter_dump('Is a snp'):
        ptitle = page.title.cdata
        ptext = page.revision.text.cdata
        # Slicing (rather than ptitle[0]) keeps an empty title from raising.
        name = ptitle[:1].lower() + ptitle[1:]

        parsed = mwparserfromhell.parse(ptext)
        snpinfo = dict(utils.extract_parameters(parsed, snp_templates,
                                                param_map))

        if 'position' in snpinfo:
            snpinfo['position'] = utils.filter_value(int, snpinfo['position'])

        # extract_parameters stores values under the *mapped* key, so the
        # comma-separated list arrives as 'genes' (from 'gene_s'). 'gene_s'
        # is accepted too, in case the raw parameter name is ever kept.
        multi = snpinfo.pop('genes', None)
        if multi is None:
            multi = snpinfo.pop('gene_s', None)
        single = snpinfo.pop('gene', None)
        if multi is not None:
            snpinfo['genes'] = [g.strip() for g in multi.split(',') if g.strip()]
        elif single is not None:
            snpinfo['genes'] = [single.strip()]

        for template in parsed.ifilter_templates(recursive=False):
            if utils.normalize_name(template.name.strip_code()) not in snp_templates:
                continue

            genotypes = []
            snpinfo['genotypes'] = genotypes
            for n in range(1, 9):
                param = 'geno' + str(n)
                if not template.has(param):
                    continue
                g = template.get(param).value.strip_code().strip().replace(
                    ':', ';')
                matches = genotype_regex.match(g)
                if not matches:
                    print('{}: Genotype {} invalid'.format(name, g))
                    continue

                # 'D' (deletion) and an empty allele are both written as '-'.
                allele1 = matches.group(1).replace('D', '-') or '-'
                allele2 = matches.group(2).replace('D', '-') or '-'
                genotypes.append(allele1 + allele2)

            # Only the first matching template is consumed; breaking right
            # after the removal keeps the iterator from seeing a mutated tree.
            parsed.remove(template, recursive=False)
            break

        yield (name, snpinfo)
コード例 #7
0
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        """Build ``num_particles`` image classifiers trained through one parameter tensor.

        Args:
            algo: algorithm label; only echoed in the startup message.
            dataset: ``'mnist'`` or ``'cifar10'``. Selects the ``datagen``
                loader, called with ``split=True`` and unpacked into train,
                test and validation loaders.
            kernel_fn: ``'rbf'`` (uses ``rbf_fn``) or ``'cka'`` (uses
                ``kernel_cka`` and builds models with
                ``return_activations=True``).
            base_model_fn: factory called as
                ``base_model_fn(num_classes=6, return_activations=...)``.
            num_particles: number of model copies.
            resume: if true, restore models, ``param_set``, ``state_dict`` and
                the optimizer from the checkpoint for ``resume_epoch``.
            resume_epoch: checkpoint epoch; also becomes ``self.start_epoch``.
            resume_lr: Adam learning rate used after resuming.

        Raises:
            NotImplementedError: for an unsupported ``dataset`` or ``kernel_fn``.
        """
        (self.algo, self.dataset,
         self.kernel_fn, self.num_particles) = (algo, dataset, kernel_fn,
                                                num_particles)
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'mnist':
            loaders = datagen.load_mnist(split=True)
        elif self.dataset == 'cifar10':
            loaders = datagen.load_cifar10(split=True)
        else:
            raise NotImplementedError
        self.train_loader, self.test_loader, self.val_loader = loaders

        if kernel_fn == 'rbf':
            self.kernel, return_activations = rbf_fn, False
        elif kernel_fn == 'cka':
            # The CKA kernel compares activations, so the models must expose them.
            self.kernel, return_activations = kernel_cka, True
        else:
            raise NotImplementedError

        self.models = [
            base_model_fn(num_classes=6,
                          return_activations=return_activations).cuda()
            for _ in range(num_particles)
        ]
        flat_params, self.state_dict = extract_parameters(self.models)
        self.param_set = torch.nn.Parameter(flat_params.clone(),
                                            requires_grad=True)

        def make_optimizer(params, lr):
            # One parameter group; weight decay is the same before and after resume.
            return torch.optim.Adam([{
                'params': params,
                'lr': lr,
                'weight_decay': 1e-4
            }])

        self.optimizer = make_optimizer(self.param_set, 1e-3)

        if not resume:
            self.start_epoch = 0
        else:
            print('resuming from epoch {}'.format(resume_epoch))
            # NOTE(review): `model_id` is not defined in this method; unless it
            # is a module-level global, resuming raises NameError.
            checkpoint = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, model_state in zip(self.models, checkpoint['models']):
                model.load_state_dict(model_state)
            self.param_set = checkpoint['params']
            self.state_dict = checkpoint['state_dict']
            self.optimizer = make_optimizer(self.param_set, resume_lr)
            self.start_epoch = resume_epoch

        self.activation_length = self.models[0].activation_length
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.kernel_width_averager = Averager(shape=())
コード例 #8
0
        # this decides whether it's a logistic or semi-logistic policy
        tmp_config["keep_positive"] = "semi" in policy_type
        pi = LogisticPolicy(init_d, featuremap=fm, config=tmp_config)
        logger.info(f"Use {data_strategy}...")
        strat = strategies.IPSStrategy(td, config, data_strategy)
    elif policy_type == "deterministic_threshold":
        pi = DeterministicThreshold(init_y, cost, featuremap=fm)
        logger.info(f"Use {data_strategy}...")
        strat = strategies.PredictiveStrategy(td, config, data_strategy)
    else:
        raise RuntimeError(f"Unknown policy type {policy_type}")

    # Train the policy with the chosen strategy
    deployed = strat.train(pi)
    # Try to extract the parameters theta of all deployed policies (if they exist)
    all_theta = utils.extract_parameters(deployed["pis"])
    # Policies cannot be stored as arrays in .npz and are no longer needed
    del deployed["pis"]

    logger.info(
        f"Write results for {policy_type} trained on {data_strategy}...")
    suffix = policy_type + "_" + data_strategy
    fname = result_data_prefix + suffix + ".npz"
    np.savez(
        fname,
        thetas=all_theta,
        data_seeds=config["optimization"]["data_seeds"],
        **deployed,
    )

# -------------------------------------------------------------------------
コード例 #9
0
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=50,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        """Build ``num_particles`` image classifiers trained through one parameter tensor.

        Args:
            algo: algorithm label. It is echoed in the startup message and,
                when wandb logging is enabled, stored in the wandb config.
            dataset: ``'mnist'`` or ``'cifar10'``. Selects the ``datagen``
                loader, called with ``split=True``.
            kernel_fn: kernel name; only ``'rbf'`` (``rbf_fn``) is supported.
            base_model_fn: factory called as ``base_model_fn(num_classes=6)``.
                The result must support ``.cuda()`` and ``.load_state_dict()``.
            num_particles: number of model copies (default 50).
            resume: if true, restore models, ``param_set``, ``state_dict`` and
                the optimizer from the checkpoint for ``resume_epoch``.
            resume_epoch: checkpoint epoch; also becomes ``self.start_epoch``.
            resume_lr: Adam learning rate used after resuming.

        Raises:
            NotImplementedError: for an unsupported ``dataset`` or ``kernel_fn``.
        """

        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        # Hard-coded feature switches. `_use_wandb` gates the Weights & Biases
        # setup at the end of this method; `_save_model` is not read here.
        self._use_wandb = False
        self._save_model = False

        # Each loader call is unpacked into (train, test, val) loaders.
        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader, = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
        else:
            raise NotImplementedError

        # NOTE(review): num_classes is hard-coded to 6. Presumably this is a
        # subset of the dataset's classes for the open-category experiments
        # (see the wandb project name below); confirm.
        models = [
            base_model_fn(num_classes=6).cuda() for _ in range(num_particles)
        ]

        self.models = models
        # extract_parameters is a project helper; presumably it flattens the
        # models' weights into one tensor plus the metadata needed to write
        # them back. Confirm against its definition.
        param_set, state_dict = extract_parameters(self.models)

        self.state_dict = state_dict
        # Trainable copy of the extracted parameters; the optimizer below
        # updates only this tensor.
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            # NOTE(review): `model_id` is not defined in this method; unless it
            # is a module-level global, resuming raises NameError.
            # The checkpoint dict must provide 'models' (one state dict per
            # model), 'params' and 'state_dict'.
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            # Replaces the freshly built Parameter, so the optimizer is
            # rebuilt on the restored tensor with the resume learning rate.
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        # loss_type is hard-coded to 'ce', so the 'kliep' branch below is
        # currently unreachable.
        loss_type = 'ce'
        if loss_type == 'ce':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        elif loss_type == 'kliep':
            self.loss_fn = MattLoss().get_loss_dict()['kliep']
        self.kernel_width_averager = Averager(shape=())

        # Weights & Biases setup. It only runs if `_use_wandb` (set False
        # above) is switched on. It watches every model and records the run
        # configuration.
        if self._use_wandb:
            wandb.init(project="open-category-experiments",
                       name="SVGD {}".format(self.dataset))
            for model in models:
                wandb.watch(model)
            config = wandb.config
            config.algo = algo
            config.dataset = dataset
            config.kernel_fn = kernel_fn
            config.num_particles = num_particles
            config.loss_fn = loss_type