Example #1
0
def main():
    """Entry point: load the config, prepare the environment, build the
    detector and datasets, then launch training."""
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Configure the training environment (distribution, cudnn_benchmark,
    # random seed for reproducibility).
    env.set_env(cfg.env_config)

    # Initialise logging before any other step emits messages.
    logger = log.get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(True))

    if cfg.checkpoint_config is not None:
        # Record the satdet version and full config text as checkpoint
        # metadata so saved models are self-describing.
        cfg.checkpoint_config.meta = {
            'satdet_version': __version__,
            'config': cfg.text,
        }

    detector = build_detector(cfg.model,
                              train_cfg=cfg.train_cfg,
                              test_cfg=cfg.test_cfg)

    datasets = [get_dataset(cfg.data.train), get_dataset(cfg.data.val)]
    train_detector(detector, datasets, cfg, logger=logger)
Example #2
0
def config() -> Config:
    """Build the training configuration for the A2C conv experiment."""
    cfg = Config()
    set_env(cfg, EXPAND)
    cfg.set_optimizer(lambda params: Adam(params, lr=2.5e-4, eps=1.0e-4))
    cfg.set_net_fn('actor-critic', a2c_conv())
    # Training / logging hyper-parameters.
    cfg.grad_clip = 0.5
    cfg.episode_log_freq = 100
    cfg.eval_deterministic = False
    return cfg
Example #3
0
def config() -> Config:
    """Build the beta-VAE variant of the experiment configuration."""
    cfg = vae.patched_config()
    # VAE auxiliary-loss settings.
    cfg.vae_loss_weight = 1.0
    cfg.vae_loss = vae.BetaVaeLoss(beta=4.0, decoder_type='categorical_binary')
    set_env(cfg, EXPAND)
    cfg.set_optimizer(lambda params: Adam(params, lr=2.5e-4, eps=1.0e-4))
    cfg.set_net_fn('actor-critic', net)
    # Training / logging hyper-parameters.
    cfg.grad_clip = 0.5
    cfg.episode_log_freq = 100
    cfg.eval_deterministic = False
    cfg.network_log_freq = 100
    return cfg
Example #4
0
def setup(config):
    """Construct the gecko/wpt repositories plus GitHub and Bugzilla
    handles, and register them in the process environment.

    Args:
        config: nested sync configuration mapping; must provide
            ``config["web-platform-tests"]["github"]["token"]``,
            ``config["web-platform-tests"]["repo"]["url"]`` and
            ``config["sync"]["enabled"]``.

    Returns:
        tuple: ``(git_gecko, git_wpt)`` repository objects.
    """
    gecko_repo = repos.Gecko(config)
    git_gecko = gecko_repo.repo()
    wpt_repo = repos.WebPlatformTests(config)
    git_wpt = wpt_repo.repo()
    gh_wpt = gh.GitHub(config["web-platform-tests"]["github"]["token"],
                       config["web-platform-tests"]["repo"]["url"])

    bz = bug.Bugzilla(config)

    # Make the Bugzilla and GitHub handles available to the rest of the app.
    env.set_env(config, bz, gh_wpt)
    # Fix: use logging's lazy %-style arguments instead of eager string
    # interpolation, so arguments are only rendered when the record is
    # actually emitted.
    logger.info("Gecko repository: %s", git_gecko.working_dir)
    logger.info("wpt repository: %s", git_wpt.working_dir)
    logger.info("Tasks enabled: %s",
                ", ".join(config["sync"]["enabled"].keys()))
    return git_gecko, git_wpt
Example #5
0
def config() -> Config:
    """Build the gamma-VAE (capacity-annealed) experiment configuration."""
    cfg = vae.patched_config()
    # NOTE(review): set_env is kept before the VAE-loss setup because it
    # may mutate cfg fields (cfg.max_steps is read just below) — confirm.
    set_env(cfg, EXPAND)
    # VAE auxiliary-loss settings: capacity annealed from 0 up to 25.
    cfg.vae_loss_weight = 1.0
    cfg.vae_loss = vae.GammaVaeLoss(gamma=200.0,
                                    capacity_start=0.0,
                                    capacity_max=25.0,
                                    num_epochs=cfg.max_steps // (100 * 8))
    cfg.set_optimizer(lambda params: Adam(params, lr=2.5e-4, eps=1.0e-4))
    cfg.set_net_fn('actor-critic', net)
    # Training / logging hyper-parameters.
    cfg.grad_clip = 0.5
    cfg.episode_log_freq = 100
    cfg.eval_deterministic = False
    cfg.network_log_freq = 100
    return cfg
Example #6
0
import mysql.connector
from mysql.connector import errorcode
from env import set_env
import os

set_env()


class MySqlDb:
    """Thin wrapper that opens a mysql.connector connection to one database.

    Credentials are read from the MYSQL_USER / MYSQL_PASS / MYSQL_HOST
    environment variables.
    """

    def __init__(self, db_name):
        # Name of the database to connect to.
        self.db_name = db_name
        # Live connection, or None if the connection attempt failed.
        self.conn = self.mysql_connect()
        # Currently selected table name (filled in by callers).
        self.table = ""

    def create_db(self, db_name):
        # TODO: implement, e.g.
        # CREATE DATABASE IF NOT EXISTS `pythonlogin`
        #     DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
        pass

    def mysql_connect(self):
        """Open a connection to ``self.db_name``.

        Returns:
            The mysql.connector connection on success, or None on failure
            (a diagnostic message is printed).
        """
        try:
            cnx = mysql.connector.connect(user=os.getenv('MYSQL_USER'),
                                          password=os.getenv('MYSQL_PASS'),
                                          host=os.getenv('MYSQL_HOST'),
                                          database=self.db_name)
        except mysql.connector.Error as err:
            if err.errno == errorcode.ER_ACCESS_DENIED_ERROR:
                print("Something is wrong with your user name or password")
            elif err.errno == errorcode.ER_BAD_DB_ERROR:
                print("Database does not exist")
            else:
                # Fix: previously any other connector error was silently
                # swallowed; report it so failures are diagnosable.
                print(err)
            return None
        print("Connected")
        return cnx
Example #7
0
def tsvbuild(json_path, gcsbucket, suffix, pairflag, tsv_name, default,
             metaflag):
    """builds a tsv file from a directory of paired files

    Retrieves location of pairs of files matching suffix and retrieves
    metadata information from the parent directory located in an xml file
    Assumes paired files differences occur after an underscore '_'.

    Args:
        json_path (str): path to the json credentials file for GCS access
        gcsbucket (str): google cloud bucket name, recursively searched
        suffix (str): file identifying pattern being searched
        tsv_name (str): filename or tsv file
        default (bool): Use the default credentials
        metaflag (bool): Find and write metadata
        pairflag (bool): Find and write File pair

    Returns:
        str: location of tsv file
    Notes:
        format of column separation marked by tabs
        SampleName; output; predictedinsertsize; readgroup; library_name;
        platformmodel; platform; sequencingcenter; Fastq1 ; Fastq2

    Dev:
        Add dictquery to list of links
    """
    # Per-experiment parsed XML metadata, keyed by experiment file name.
    exp_dict = {}
    # Point Google auth at the supplied credentials unless defaults are used.
    if not default:
        env.set_env('GOOGLE_APPLICATION_CREDENTIALS', json_path)
    header = True
    loop = asyncio.get_event_loop()
    for gcs_url in gcloudstorage.blob_generator(gcsbucket, suffix):
        meta_dict = {}
        meta_dict['File'] = gcs_url
        if pairflag:
            gcs_pairname, gcs_pairpath, accension = pathhandling.get_fileurl(
                url=gcs_url,
                sep=suffix[0],
                suffix=suffix,
                depth=0,
                pair=pairflag)
            if gcloudstorage.blob_exists(gcs_pairpath):
                meta_dict['File_2'] = gcs_pairpath
        else:
            # seeking to create a filename
            # Fix: bind the third result to `accension` (was `parent`);
            # it is read below in the metaflag branch and was otherwise
            # undefined when pairflag is False (NameError).
            gcs_fileout, gcs_filepath, accension = pathhandling.get_fileurl(
                url=gcs_url, sep=suffix[0], suffix=suffix, depth=0, pair=True)
            meta_dict['output'] = gcs_fileout
        if metaflag:
            exp_name, exp_path, exp_folder = pathhandling.get_fileurl(
                url=gcs_url,
                sep=suffix[0],
                suffix='.experiment.xml',
                depth=1,
                pair=False)
            # EAFP: use the cached parsed XML for this experiment; on the
            # first miss, download and parse it, then retry the lookup.
            try:
                curr_dict = next(
                    dict_extract.dict_extract(value=accension,
                                              var=exp_dict[exp_name]))
            except KeyError:
                xmlfile = gcloudstorage.blob_download(exp_path)
                exp_dict[exp_name] = xmldictconv.xmldictconv(xmlfile)
                curr_dict = next(
                    dict_extract.dict_extract(value=accension,
                                              var=exp_dict[exp_name]))
            loop.run_until_complete(
                dictquery.dict_endpoints(input_dict=curr_dict,
                                         endpoint_dict=meta_dict))
        # Write one row per blob; the header only on the first row.
        tsvwriter(tsv_name, meta_dict, header)
        header = False
    loop.close()
    # Restore the environment if we modified it above.
    if not default:
        env.unset_env('GOOGLE_APPLICATION_CREDENTIALS')
    return tsv_name