Example #1
def create_app():
    """An application factory."""
    load_config()
    app = Flask(__name__)

    app.config.from_object(__name__)

    app.secret_key = b'_5#y2L"F4Q8z\n\xec]/'
    app.debug = True

    # register blueprints
    register_blueprints(app)

    return app
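
A minimal usage sketch for the factory above; this entry-point code is not from the original source and assumes `create_app` and its helpers (`load_config`, `register_blueprints`, the `flask` import) are in scope:

# A sketch, not from the original source: run the factory-built app.
if __name__ == '__main__':
    app = create_app()
    app.run()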
Example #2
def copy_csv_from_s3_to_db(bucket_name, filepath, destination_table, db_name):
    """Copies an existing csv file from s3 to a specified DB.

    Args:
        bucket_name (str): name of an existing bucket in the configured aws account.
        filepath (str): remote path in s3 bucket to the file intended to be unloaded.
        destination_table (str): structure.table to use to unload the csv content.
        db_name (str): name of the db where the table will be unloaded.
    """
    from src.config import load_config

    if filepath not in list_files_in_s3_bucket(bucket_name):
        raise FileNotFoundError(
            f"The filepath specified '{filepath}' does not exist in the"
            f" bucket '{bucket_name}'")
    creds = boto3.Session().get_credentials()
    query = (
        f"COPY {destination_table} from 's3://{bucket_name}/{filepath}' credentials "
        f"'aws_access_key_id={creds.access_key};"
        f"aws_secret_access_key={creds.secret_key}' "
        f"ignoreheader 1 removequotes delimiter ',' region 'us-east-1'")
    config = load_config("credentials")
    if db_name not in config.keys():
        raise CredentialsError(f'Credentials "{db_name}" not found')
    credentials = {k: v for k, v in config.items() if type(v) is not dict}
    credentials = {**credentials, **config[db_name]}

    with connect(**credentials) as conn:
        conn.cursor().execute(query)
        conn.commit()
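
A hedged call sketch for the function above; every name below (bucket, path, table, credentials key) is hypothetical, not from the source:

# Hypothetical invocation; all identifiers are illustrative only.
copy_csv_from_s3_to_db(
    bucket_name="example-bucket",
    filepath="exports/2020/data.csv",
    destination_table="public.example_table",
    db_name="warehouse",
)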
Example #3
    def from_config(cls, path: Path) -> 'Server':
        server_config, wallet_config, users = load_config(path)
        cls_ = cls(server_config.getint('port'), wallet_config)

        cls_.users.update(users)

        return cls_
Example #4
def post(sourcetype: str, eventtype: str, fields: dict):
    try:
        cfg = config.load_config()
        headers = {"Authorization": f"Splunk {cfg['token']}"}

        data = {
            "index": cfg['index'],
            "host": str(socket.gethostname()),
            "sourcetype": sourcetype,
            "event": eventtype,
            "fields": fields
        }

        # print('URL: ' + cfg['url'] + '\nHeaders: ' + json.dumps(headers) + '\nData: ' + json.dumps(data, indent=4))
        requests.packages.urllib3.disable_warnings(
            category=InsecureRequestWarning)
        response = requests.post(cfg['url'],
                                 headers=headers,
                                 json=data,
                                 verify=False)
        print(
            f"[{int(time.time())}] Uploaded {data['fields']['metric_name']} -> Response: {response.status_code}"
        )
        return response

    except Exception as e:
        print(str(e))
Example #5
def _parse_credentials(query):
    # Query files should specify associated credentials on first line
    with open(_get_query_path(query)) as f:
        header = f.readline().strip()

    config = load_config("credentials")
    keys = {k for k, v in config.items() if type(v) is dict}

    prefix = "-- Credentials: "
    if not header.startswith(prefix):
        msg = f'Query "{query}" missing header "{prefix}[{", ".join(keys)}]"'
        raise QueryError(msg)

    key = header[len(prefix):]

    if key not in keys:
        raise CredentialsError(f'Credentials "{key}" not found')

    credentials = {k: v for k, v in config.items() if type(v) is not dict}
    credentials = {**credentials, **config[key]}

    if not credentials["user"]:
        raise CredentialsError(f'Username for "{key}" not found')

    if not credentials["password"]:
        raise CredentialsError(f'Password for "{key}" not found')

    return credentials
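
For reference, `_parse_credentials` requires the first line of a query file to carry the `-- Credentials: ` prefix followed by a key naming a nested dict in the "credentials" config. A sketch of a conforming file, with "warehouse" as a hypothetical key:

# Sketch of a query file accepted by _parse_credentials (contents hypothetical).
query_file_contents = "\n".join([
    "-- Credentials: warehouse",                 # header checked against config keys
    "SELECT * FROM events WHERE day = '{day}'",  # body read later by _prepare_query
])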
Example #6
 def test_load_config(self):
     config = load_config(["default"])
     self.assertDictEqual(
         {
             k: v
             for k, v in _INITIAL_CONFIG["groups"].items()
             if k != "default_groups"
         }, config)
Example #7
def download_cache_string(cache_id):
    """
    Fetch cached data as a string. Returns None if the cache does not exist.
    """
    config = load_config()
    with tempfile.NamedTemporaryFile() as fd:
        config['cache_client'].download_cache(cache_id, fd.name)
        contents = fd.read().decode()
        print(f'Downloaded contents from cache ID {cache_id[0:16]}..')
        return contents
Example #8
def set_env(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.threads)
    current_time = datetime.now().strftime('%b_%d_%H-%M-%S')
    model_args = load_config(args.model_name, args.dataset_str, args.optimizer)
    if args.model_name in ['DPGVAE', 'GVAE']:
        args.discriminator_ratio = 0
    return args, model_args, current_time
Example #9
def _add_item(pipeline_name):
    url = request.form.get("url")
    config = load_config()
    pipeline = config.pipeline[pipeline_name]
    for p in pipeline.pull:
        pull = pull_services[p.service]
        item_id = pull.parse_item_id(url)
        if item_id is None:
            continue
        ItemInfo.add_index(IndexItem(service=p.service, item_id=item_id),
                           [pipeline_name])
        ItemInfo.set_status(p.service, item_id, TaskStage.Fetching,
                            TaskStatus.Queued)
        return redirect(f"/pipeline/{pipeline_name}")
    return "Unknown URL"
Example #10
def show_config(json_data, auth_token=None):
    config = load_config()
    # Explicitly write in config entries we want to return. Do not show anything private.
    result = {
        'homology_url': config['homology_url'],
        'id_mapper_url': config['id_mapper_url'],
        'caching_service_url': config['caching_service_url'],
        'kbase_endpoint': config['kbase_endpoint'],
        'cache_client': config['client']
    }
    return flask.jsonify({
        'version': '1.1',
        'id': json_data.get('id'),
        'result': result
    })
Example #11
def _failure_browser_root():
    config = load_config()
    service_types = [
        t for t in [
            ServiceType.Twitter, ServiceType.Pixiv, ServiceType.Fanbox,
            ServiceType.Weibo
        ] if t in config.api
    ]
    stage_list = [
        TaskStage.Fetching, TaskStage.Downloading, TaskStage.Posting,
        TaskStage.Cleaning
    ]
    return render_template('failures_index.jinja2',
                           stage_list=stage_list,
                           service_types=service_types)
Example #12
 def test_load_config_with_vmrc(self):
     # This is the case where the config is extended from a .vmrc file in the current directory.
     copyfile(
         os.path.join(dirname(dirname(__file__)), 'test_resources',
                      _RC_FILE), os.path.join(os.getcwd(), _RC_FILE))
     config = load_config(["test"])
     expected_config = {
         "test": {
             "files": [{
                 "names": ["test_files"]
             }],
             "excludes": ["test_excludes"]
         }
     }
     self.assertEqual(expected_config, config)
     os.remove(os.path.join(os.getcwd(), _RC_FILE))
Example #13
def _pipeline(pipeline_name):
    config = load_config()
    pipeline = config.pipeline[pipeline_name]
    subs = []
    for s in pipeline.subscribe:
        svc = subscribe_services[s.service]
        items = [(n, svc.get_title(n), svc.get_url(n))
                 for n, channels in SubscribeSource.get_subs_by_channel(
                     *s.service, pipeline_name)]
        options = svc.options()
        subs.append((s.service[0].value, s.service[1], items, len(items), options))
    status = ItemInfo.count_status()
    return render_template('pipeline.jinja2',
                           pipeline_name=pipeline_name,
                           subs=subs,
                           status=status)
Example #15
def update_subs():
    config = load_config()
    for (stype, sfunc), service_type in subscribe_services.items():
        if stype not in config.api:
            continue
        service_conf = list(config.api[stype].values())[0]
        service = service_type(service_conf)
        for name, channels in SubscribeSource.get_subs(stype, sfunc):
            for item in service.subscribe_index(name):
                if not ItemInfo.exists(item.service, item.item_id):
                    ItemInfo.add_index(item, channels)
                    print(stype.value, sfunc, name, item)
                    ItemInfo.set_status(item.service, item.item_id,
                                        TaskStage.Fetching, TaskStatus.Queued)
            for item in service.subscribe_full(name):
                if not ItemInfo.exists(item.service, item.item_id):
                    ItemInfo.add_item(item, channels)
                    print(stype.value, sfunc, name, item)
                    ItemInfo.set_status(item.service, item.item_id,
                                        TaskStage.Downloading,
                                        TaskStatus.Queued)
Example #16
def autodownload(ref, save_dir, auth_token):
    """
    Autodownload the fasta/fastq file for a Genome, Reads, or Assembly.
    Args:
      ref is a workspace reference ID in the form 'workspace_id/object_id/version'
      save_dir is the path of a directory in which to save the downloaded file
    Returns a tuple of (file_path, paired_end)
      file_path is the string path of the saved file
      paired_end is a boolean indicating if these are paired-end reads
    The generate_sketch function needs to know if it's working with paired-end reads or not
    """
    config = load_config()
    ws = WorkspaceClient(url=config["kbase_endpoint"], token=auth_token)
    ws_obj = ws.req("get_objects2", {'objects': [{"ref": ref}], 'no_data': 1})

    ws_type = ws_obj['data'][0]['info'][2]
    if valid_types['reads_paired'] in ws_type:
        paths = ws.download_reads_fastq(ref, save_dir)
        output_path = paths[0].replace(".paired.fwd.fastq", ".fastq")
        concatenate_files(paths, output_path)
        print(f'Downloaded fastq file(s) to {output_path}')
        return (output_path, True)
    elif valid_types['reads_single'] in ws_type:
        paths = ws.download_reads_fastq(ref, save_dir)
        output_path = paths[0]
        print(f'Downloaded fastq file(s) to {output_path}')
        return (output_path, False)
    elif valid_types['assembly'] in ws_type or valid_types[
            'assembly_legacy'] in ws_type:
        path = ws.download_assembly_fasta(ref, save_dir)
        print(f'Downloaded fasta file(s) from Assembly to {path}')
        return (path, False)
    elif valid_types['genome'] in ws_type:
        ref = ws.get_assembly_from_genome(ref)
        path = ws.download_assembly_fasta(ref, save_dir)
        print(f'Downloaded fasta file(s) from Genome to {path}')
        return (path, False)
    else:
        raise UnrecognizedWSType(ws_type, valid_types)
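
A hedged call sketch; per the docstring the ref takes the 'workspace_id/object_id/version' form, and the values below are made up:

# Hypothetical invocation; ref, directory, and token are illustrative only.
file_path, paired_end = autodownload('15/8/1', '/tmp/downloads', auth_token='<token>')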
Example #17
def _prepare_query(query):
    with open(_get_query_path(query)) as f:
        _ = f.readline()  # Skip credentials header line
        query_sql = f.read()

    # Extract query parameters from raw SQL
    params = [p[1:-1] for p in findall(r"{[a-zA-Z_][a-zA-Z_\d]*}", query_sql)]

    if not params:
        return query_sql

    args = load_config("queries")[query]

    missing = list(set(params).difference(args.keys()))
    if missing:
        raise QueryError(f'Query "{query}" missing argument(s): {missing}')

    unused = list(set(args.keys()).difference(params))
    if unused:
        raise QueryError(f'Unused argument(s) for query "{query}": {unused}')

    return query_sql.format(**args)
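
`_prepare_query` resolves each `{name}` placeholder from the "queries" section of the config. A minimal sketch of the entry it would expect for the hypothetical query file shown under Example #5:

# Sketch of a "queries" config entry (query name and argument are hypothetical).
queries_config = {
    "daily_events": {"day": "2020-01-01"},
}
# _prepare_query("daily_events") would then produce:
#   SELECT * FROM events WHERE day = '2020-01-01'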
Example #18
def map_refseq_ids_to_kbase(distances):
    """
    Given the results from a request to the AssemblyHomologyService, iterate over every Refseq ID
    in the results and fetch the corresponding KBase ID using the ID Mapping service.
    Args:
      distances - an array of search result objects from the AssemblyHomologyService
    AssemblyHomology API: https://github.com/jgi-kbase/AssemblyHomologyService#api
    """
    # Create a list of Refseq IDs
    config = load_config()
    refseq_ids = [d['sourceid'] for d in distances]
    req_data = {"ids": refseq_ids}
    req_json = json.dumps(req_data)
    endpoint = config['id_mapper_url'] + '/mapping/RefSeq'
    print(f"Getting KBase IDs for {refseq_ids} using endpoint {endpoint}")
    resp = requests.get(endpoint, data=req_json, timeout=999)
    # Handle any error case from the ID Mapper by exiting and logging everything
    if not resp.ok:
        print('=' * 80)
        print(f'ID Mapping error with status code {resp.status_code}')
        print(resp.text)
        print('=' * 80)
        raise Exception(f"Error from the ID Mapping service: {resp.text}")
    resp_json = resp.json()
    print('  response', resp.text)
    # Create a dict of indexes where each key is the refseq ID so we can refer to it below
    indexes = {}
    for (idx, dist) in enumerate(distances):
        indexes[dist['sourceid']] = idx
    # Find all KBase ids in the given mappings
    for (refseq_id, result) in resp_json.items():
        distance_idx = indexes[refseq_id]
        mappings = result['mappings']
        for mapping in mappings:
            if 'KBase' == mapping['ns']:
                kbase_id = mapping['id']
                distances[distance_idx]['kbase_id'] = kbase_id
    return distances
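
The function annotates the input list in place; a sketch of the expected shapes, with hypothetical IDs:

# Hypothetical input: each search result carries a RefSeq 'sourceid'.
distances = [{'sourceid': 'GCF_000005845.2', 'dist': 0.0}]
# After map_refseq_ids_to_kbase, mapped entries gain a 'kbase_id':
#   [{'sourceid': 'GCF_000005845.2', 'dist': 0.0, 'kbase_id': '15/8/1'}]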
Example #19
def generate_sketch(file_path, search_db, paired_end=False):
    """
    Generate a sketch file from a given downloaded fasta/fastq file.
    Args:
      downloaded_file is a DownloadedFile namedtuple defined in ./download_file.py
    Returns the full path of the sketch file
    """
    config = load_config()
    # Fetch the k-mer size
    url = f"{config['homology_url']}/namespace/{search_db}"
    resp = requests.get(url, timeout=999)
    json_resp = resp.json()
    sketch_size = str(json_resp.get('sketchsize', 10000))
    kmer_size = json_resp.get('kmersize', 19)
    if isinstance(kmer_size, list):
        kmer_size = kmer_size[0]
    output_name = os.path.basename(file_path + '.msh')
    output_path = os.path.join(os.path.dirname(file_path), output_name)
    args = ['mash', 'sketch', file_path, '-o', output_path, '-k', str(kmer_size), '-s', sketch_size]
    if paired_end:
        # For paired end reads, sketch the reads using -m 2 to improve results by ignoring
        # single-copy k-mers, which are more likely to be erroneous.
        # See docs:
        # http://mash.readthedocs.io/en/latest/tutorials.html#querying-read-sets-against-an-existing-refseq-sketch
        args += ['-m', '2']
    print(f"Generating sketch with command: {' '.join(args)}")
    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)  # nosec
    (stdout, stderr) = proc.communicate()
    print('-' * 80)
    print('mash output:')
    print(stdout)
    print(stderr)
    print('-' * 80)
    if proc.returncode != 0:
        raise Exception(f"Error generating sketch data: {stderr}")
    return output_path
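
A hedged usage sketch; the input path and namespace are hypothetical, and the comment shows the command line the function would assemble given the defaults above:

# Hypothetical call; path and namespace are illustrative.
sketch_path = generate_sketch('/tmp/reads.fastq', 'NCBI_Refseq', paired_end=True)
# Assembled command (with defaults sketch_size=10000, kmer_size=19):
#   mash sketch /tmp/reads.fastq -o /tmp/reads.fastq.msh -k 19 -s 10000 -m 2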
Example #20
 def test_create_folder(self):
     config = load_config()
     conf: MegaConfig = config.api[ServiceType.Mega]['default']
     client = mega.Mega()
     client.login(conf.username, conf.password)
     client.create_folder(conf.root)
Example #21
        help='Verbose output. Changes log level from INFO to DEBUG.')
    parser.add_argument(
        '--config',
        help='Specify a configuration file (defaults to ./config.yml)')
    parser.add_argument('-l',
                        '--logfile',
                        help="Log file to append logs to.",
                        default=None)
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.v else logging.INFO,
                        filename=args.logfile,
                        format="%(asctime)-15s: %(message)s")
    enable_color_logging(debug_lvl=logging.DEBUG if args.v else logging.INFO)
    logger.info(intro())
    CONFIG = load_config(args.config or "./config.yml")
    li = lichess.Lichess(CONFIG["token"], CONFIG["url"], __version__)

    user_profile = li.get_profile()
    username = user_profile["username"]
    is_bot = user_profile.get("title") == "BOT"
    logger.info("Welcome {}!".format(username))

    if args.u and not is_bot:
        is_bot = upgrade_account(li)

    if is_bot:
        engine_factory = partial(engine_wrapper.create_engine, CONFIG)
        start(li, user_profile, engine_factory, CONFIG)
    else:
        logger.error(
Example #22
    def run(self):
        # read configuration files
        config_root = self.ENV.get("CONFIG_ROOT", "../config")
        config_schemas = config.load_config(config_root + "/schemas.yml",
                                            env=self.ENV)
        config_services = config.load_config(config_root + "/services.yml",
                                             env=self.ENV)
        app_config["config_root"] = config_root
        app_config["config_services"] = config_services

        agent_admin_url = self.ENV.get("AGENT_ADMIN_URL")
        if not agent_admin_url:
            raise RuntimeError(
                "Error AGENT_ADMIN_URL is not specified, can't connect to Agent."
            )
        app_config["AGENT_ADMIN_URL"] = agent_admin_url

        # get public DID from our agent
        response = requests.get(
            agent_admin_url + "/wallet/did/public",
            headers=ADMIN_REQUEST_HEADERS,
        )
        result = response.json()
        did = result["result"]
        LOGGER.info("Fetched DID from agent: %s", did)
        app_config["DID"] = did["did"]

        # determine pre-registered schemas and cred defs
        existing_schemas = agent_schemas_cred_defs(agent_admin_url)

        # register schemas and credential definitions
        for schema in config_schemas:
            schema_name = schema["name"]
            schema_version = schema["version"]
            schema_key = schema_name + "::" + schema_version
            if schema_key not in existing_schemas:
                schema_attrs = []
                schema_descs = {}
                if isinstance(schema["attributes"], dict):
                    # each element is a dict
                    for attr, desc in schema["attributes"].items():
                        schema_attrs.append(attr)
                        schema_descs[attr] = desc
                else:
                    # assume it's an array
                    for attr in schema["attributes"]:
                        schema_attrs.append(attr)

                # register our schema(s) and credential definition(s)
                schema_request = {
                    "schema_name": schema_name,
                    "schema_version": schema_version,
                    "attributes": schema_attrs,
                }
                response = agent_post_with_retry(
                    agent_admin_url + "/schemas",
                    json.dumps(schema_request),
                    headers=ADMIN_REQUEST_HEADERS,
                )
                response.raise_for_status()
                schema_id = response.json()
            else:
                schema_id = {
                    "schema_id": existing_schemas[schema_key]["schema"]["id"]
                }
            app_config["schemas"]["SCHEMA_" + schema_name] = schema
            app_config["schemas"]["SCHEMA_" + schema_name + "_" +
                                  schema_version] = schema_id["schema_id"]
            LOGGER.info("Registered schema: %s", schema_id)

            if (schema_key not in existing_schemas
                    or "cred_def" not in existing_schemas[schema_key]):
                cred_def_request = {"schema_id": schema_id["schema_id"]}
                response = agent_post_with_retry(
                    agent_admin_url + "/credential-definitions",
                    json.dumps(cred_def_request),
                    headers=ADMIN_REQUEST_HEADERS,
                )
                response.raise_for_status()
                credential_definition_id = response.json()
            else:
                credential_definition_id = {
                    "credential_definition_id":
                    existing_schemas[schema_key]["cred_def"]["id"]
                }
            app_config["schemas"]["CRED_DEF_" + schema_name + "_" +
                                  schema_version] = credential_definition_id[
                                      "credential_definition_id"]
            LOGGER.info("Registered credential definition: %s",
                        credential_definition_id)

        # what is the TOB connection name?
        tob_connection_params = config_services["verifiers"]["bctob"]

        # check if we have a TOB connection
        response = requests.get(
            agent_admin_url + "/connections?alias=" +
            tob_connection_params["alias"],
            headers=ADMIN_REQUEST_HEADERS,
        )
        response.raise_for_status()
        connections = response.json()["results"]
        tob_connection = None
        for connection in connections:
            # check for TOB connection
            if connection["alias"] == tob_connection_params["alias"]:
                tob_connection = connection

        if not tob_connection:
            # if no tob connection then establish one (if we can)
            # (agent_admin_url is provided if we can directly ask the TOB agent for an invitation,
            #   ... otherwise the invitation has to be provided manually through the admin api
            #   ... WITH THE CORRECT ALIAS)
            if ("agent_admin_url" in tob_connection_params["connection"] and
                    tob_connection_params["connection"]["agent_admin_url"]):
                tob_agent_admin_url = tob_connection_params["connection"][
                    "agent_admin_url"]
                response = requests.post(
                    tob_agent_admin_url + "/connections/create-invitation",
                    headers=TOB_REQUEST_HEADERS,
                )
                response.raise_for_status()
                invitation = response.json()

                response = requests.post(
                    agent_admin_url +
                    "/connections/receive-invitation?alias=" +
                    tob_connection_params["alias"],
                    json.dumps(invitation["invitation"]),
                    headers=ADMIN_REQUEST_HEADERS,
                )
                response.raise_for_status()
                tob_connection = response.json()

                LOGGER.info("Established tob connection: %s",
                            json.dumps(tob_connection))
                time.sleep(5)

        # if we have a connection to the TOB agent, we can register our issuer
        if tob_connection:
            register_issuer_with_orgbook(tob_connection["connection_id"])
        else:
            print(
                "No TOB connection found or established, awaiting invitation to connect to TOB ..."
            )
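
The schema loop above accepts `attributes` either as a mapping (attribute name to description) or as a plain list. A sketch of the parsed shape of two conforming schemas.yml entries (all names hypothetical):

# Parsed shape of hypothetical schemas.yml entries accepted by the loop above.
config_schemas = [
    {"name": "registration", "version": "1.0.0",
     "attributes": {"corp_num": "Corporation number", "legal_name": "Legal name"}},
    {"name": "relationship", "version": "1.0.0",
     "attributes": ["corp_num", "associated_corp_num"]},
]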
Example #24
import traceback
from functools import lru_cache
from io import BytesIO

from src.config import load_config
from src.enums import TaskStatus, TaskStage, ServiceType
from src.models.connect import connect_db
from src.models.item import ItemInfo, SecondaryTask
from src.services import push_services, pull_services

config = load_config()


def post_images():
    for item in ItemInfo.poll_status(TaskStage.Posting, TaskStatus.Queued):
        channels = ItemInfo.get_channels(item)
        for ch in channels:
            for pipe in config.pipeline[ch].push:
                SecondaryTask.add_task(item.service, item.item_id, pipe.service, pipe.config, ch)
        ItemInfo.set_status(item.service, item.item_id, TaskStage.Posting, TaskStatus.Pending)
    for stype, item_id, ptype, conf, ch, poll_counter in SecondaryTask.poll_tasks(20):
        SecondaryTask.acquire_task(stype, item_id, ptype, conf, ch)
        print((stype.value, item_id), '=>', (ptype.value, conf))
        item = ItemInfo.get_item(stype, item_id)
        if not service_exists(ptype, conf):
            continue
        client = get_service(ptype, conf)
        if poll_counter >= client.push_limit():
            print("Failed to push item.")
            SecondaryTask.close_task(stype, item_id, ptype, conf, ch)
        else:
Example #25
def launch():
    conf = load_config().server
    connect_db()
    app.run(host=conf.host, port=conf.port)
Example #26
import os
import sys
from logging.config import fileConfig
from sqlalchemy import create_engine
from sqlalchemy import pool

from alembic import context

from src.config import load_config
from src.infra.db.session import metadata

CONFIG = load_config()

parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(parent_dir)

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
from src.infra.db.mapper import start_mappers

start_mappers()
target_metadata = metadata
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
Example #27
def get_cache_id(data):
    # Generate the cache_id
    config = load_config()
    cache_id = config['cache_client'].generate_cacheid(data)
    print(f'Fetched cache ID {cache_id[0:16]}..')
    return cache_id
Example #28
def _index():
    pipelines = load_config().pipeline.items()
    status = ItemInfo.count_status()
    return render_template('index.jinja2', pipelines=pipelines, status=status)
Example #29
def upload_to_cache(cache_id, string):
    """Save string content to a cache."""
    config = load_config()
    print(f'Uploading contents to cache {cache_id[0:16]}..')
    config['cache_client'].upload_cache(cache_id, string=string)
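
A round-trip sketch tying together the cache helpers from Examples #27, #29, and #7; the payload and input data are hypothetical:

# Hypothetical round trip through the caching helpers above.
cache_id = get_cache_id({'ref': '15/8/1', 'search_db': 'NCBI_Refseq'})
upload_to_cache(cache_id, 'cached payload')
assert download_cache_string(cache_id) == 'cached payload'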
Example #30
 def test_write(self):
     config = load_config()
     conf: MegaConfig = config.api[ServiceType.Mega]['default']
     client = MegaService(conf)
     print(client.ensure_dir(Path("/xnh/b/c/d/e")))
     print(client.client.find("/xnh/b/c/d/e", exclude_deleted=True))
Example #31
 def __init__(self, groups):
     self._loaded_files = []
     self.config = load_config(groups)
Example #32
# Arguments
parser = argparse.ArgumentParser(
    description='Train a 3D reconstruction model.')
parser.add_argument('config', type=str, help='Path to config file.')
#parser.add_argument('model_name', type=str, default='model', help='Model output file, i.e. for foo.pt insert foo')
parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
parser.add_argument(
    '--exit-after',
    type=int,
    default=-1,
    help='Checkpoint and exit after specified number of seconds '
    'with exit code 2.')

args = parser.parse_args()
cfg = config.load_config(args.config, 'configs/default.yaml')
is_cuda = (torch.cuda.is_available() and not args.no_cuda)
device = torch.device("cuda" if is_cuda else "cpu")
#DEGREES = cfg['degrees']
DEGREES = 1

#model_name = args.model_name
##
# Set t0
t0 = time.time()

# Shorthands
out_dir = cfg['training']['out_dir']
batch_size = cfg['training']['batch_size']
backup_every = cfg['training']['backup_every']
vis_n_outputs = cfg['generation']['vis_n_outputs']
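
The shorthand lookups above imply this nested shape for the loaded YAML config; the values shown are placeholders, not from the source:

# Nested shape implied by the cfg lookups above (values are hypothetical).
cfg_shape = {
    'training': {'out_dir': 'out/demo', 'batch_size': 64, 'backup_every': 1000},
    'generation': {'vis_n_outputs': 2},
}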