def get_auth_params(conf: ConfigTree, discover_auth: bool = False) -> Dict[str, Any]:
    if discover_auth:
        # Mode discovery API needs custom token set in header
        # https://mode.com/developer/discovery-api/introduction/
        params = {
            "headers": {
                "Authorization": conf.get_string(MODE_BEARER_TOKEN),
            }
        }  # type: Dict[str, Any]
    else:
        params = {
            'auth': HTTPBasicAuth(conf.get_string(MODE_ACCESS_TOKEN),
                                  conf.get_string(MODE_PASSWORD_TOKEN))
        }
    return params
def logging_setup(name: str, config: ConfigTree, set_default: bool = False):
    """
    Initialize a logger based on the specified configuration.

    :param name: the name of the logger (use the empty string to configure the root logger)
    :param config: the configuration object
    :param set_default: set this configuration for the default (root) logger as well
    """
    # read config
    log_conf = config.get_config('logging')
    format = log_conf.get('format', '{asctime} [{levelname}]: {message}')
    stdout = log_conf.get_bool('console.enabled', True)
    stdout_level = logging._nameToLevel[log_conf.get('console.level', 'info').upper()]
    logfile = log_conf.get_bool('logfile.enabled', False)
    logfile_level = logging._nameToLevel[log_conf.get('logfile.level', 'info').upper()]
    logfile_path = log_conf.get('logfile.path')

    # setup logging
    if set_default:
        util.logging_setup('', format=format, stdout=stdout, stdout_level=stdout_level)
    util.logging_setup(name, format=format, stdout=stdout, stdout_level=stdout_level,
                       logfile=logfile, logfile_level=logfile_level, logfile_path=logfile_path)
def clone_task(cls, cloned_task_id, name, comment=None, execution_overrides=None,
               tags=None, parent=None, project=None, log=None, session=None):
    """
    Clone a task

    :param session: Session object used for sending requests to the API
    :type session: Session
    :param cloned_task_id: Task ID for the task to be cloned
    :type cloned_task_id: str
    :param name: New name for the new task
    :type name: str
    :param comment: Optional comment for the new task
    :type comment: str
    :param execution_overrides: Task execution overrides. Applied over the cloned task's
        execution section, useful for overriding values in the cloned task.
    :type execution_overrides: dict
    :param tags: Optional updated model tags
    :type tags: [str]
    :param parent: Optional parent ID of the new task.
    :type parent: str
    :param project: Optional project ID of the new task. If None, the new task will inherit
        the cloned task's project.
    :type project: str
    :param log: Log object used by the infrastructure.
    :type log: logging.Logger
    :return: The new task's ID
    """
    session = session if session else cls._get_default_session()
    res = cls._send(session=session, log=log, req=tasks.GetByIdRequest(task=cloned_task_id))
    task = res.response.task
    output_dest = None
    if task.output:
        output_dest = task.output.destination
    execution = task.execution.to_dict() if task.execution else {}
    execution = ConfigTree.merge_configs(
        ConfigFactory.from_dict(execution),
        ConfigFactory.from_dict(execution_overrides or {}))
    req = tasks.CreateRequest(
        name=name,
        type=task.type,
        input=task.input,
        tags=tags if tags is not None else task.tags,
        comment=comment or task.comment,
        parent=parent,
        project=project if project else task.project,
        output_dest=output_dest,
        execution=execution.as_plain_ordered_dict(),
        script=task.script,
    )
    res = cls._send(session=session, log=log, req=req)
    return res.response.id
def init(self, conf: ConfigTree):
    self.conf = conf.with_fallback(WhaleLoader.DEFAULT_CONFIG)
    self.base_directory = self.conf.get_string("base_directory")
    self.tmp_manifest_path = self.conf.get_string("tmp_manifest_path", None)
    self.database_name = self.conf.get_string("database_name", None)
    Path(self.base_directory).mkdir(parents=True, exist_ok=True)
    Path(paths.MANIFEST_DIR).mkdir(parents=True, exist_ok=True)
def init(self, conf: ConfigTree): """ Establish connection, import data model class (if provided) :param conf: configuration file. """ self.conn_string = conf.get_string(SQLAlchemyEngine.CONN_STRING_KEY) self.credentials_path = conf.get(SQLAlchemyEngine.CREDENTIALS_PATH_KEY, None) self.connection = self._get_connection() model_class = conf.get(SQLAlchemyEngine.MODEL_CLASS_KEY, None) if model_class: module_name, class_name = model_class.rsplit(".", 1) mod = importlib.import_module(module_name) self.model_class = getattr(mod, class_name)
def init(self, conf: ConfigTree) -> None:
    self.conf = conf.with_fallback(SpliceMachineMetadataExtractor.DEFAULT_CONFIG)
    self._database = self.conf.get_string(SpliceMachineMetadataExtractor.DATABASE_KEY)
    self._cluster = self.conf.get_string(SpliceMachineMetadataExtractor.CLUSTER_KEY)
    self._where_clause_suffix = self.conf.get_string(SpliceMachineMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY)
    self._username = self.conf.get_string(SpliceMachineMetadataExtractor.USERNAME_KEY)
    self._password = self.conf.get_string(SpliceMachineMetadataExtractor.PASSWORD_KEY)
    self._host = self.conf.get_string(SpliceMachineMetadataExtractor.HOST_KEY)

    context = {
        "where_clause_suffix": self._where_clause_suffix,
    }

    j2_env = Environment(
        loader=FileSystemLoader(os.path.dirname(os.path.abspath(__file__))),
        trim_blocks=True)
    self.sql_statement = j2_env.get_template('splice_machine_metadata_extractor.sql').render(context)

    LOGGER.info("SQL for splicemachine: {}".format(self.sql_statement))

    self._extract_iter = None
    self.connection = splice_connect(self._username, self._password, self._host)
    self.cursor = self.connection.cursor()
def project(argv=sys.argv):
    global instance
    if instance is not None:
        return instance
    else:
        pattern = re.compile('-D(.*)=(.*)')
        conf_override = dict()
        argv_filtered = []
        for a in argv:
            m = pattern.match(a)
            if m is not None:
                conf_override[m.group(1)] = m.group(2)
            else:
                argv_filtered.append(a)

        parser = ArgumentParser()
        parser.add_argument('--conf', default='application.conf')
        args, other = parser.parse_known_args(argv_filtered)

        conf = ConfigFactory.parse_file(args.conf)
        conf_override = ConfigFactory.from_dict(conf_override)
        conf_merged = ConfigTree.merge_configs(conf, conf_override)
        instance = Project(conf=conf_merged)
        return instance
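# Minimal usage sketch for project() above (key and values are illustrative, not from
# the source): a "-Dkey=value" argument is collected into conf_override and, after
# ConfigFactory.from_dict / ConfigTree.merge_configs, wins over the same key parsed
# from application.conf.
from pyhocon import ConfigFactory, ConfigTree

file_conf = ConfigFactory.parse_string("mist.http.port = 2003")  # stands in for application.conf
cli_conf = ConfigFactory.parse_string("mist.http.port = 2004")   # stands in for -Dmist.http.port=2004
merged = ConfigTree.merge_configs(file_conf, cli_conf)
assert merged.get_int('mist.http.port') == 2004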
def __init__(
    self,
    langs: List[str],
    added_nodes: Dict[str, Dict[str, str]],
    skip_node_types: Dict[str, List[str]],
    vendors_path: Path = Path("./vendor"),
):
    super(TreeSitterParser, self).__init__()
    vendors = []
    self.added_nodes = added_nodes
    self.skip_node_types = skip_node_types
    for lang in langs:
        vendors.append(vendors_path / f"tree-sitter-{lang}")
        if lang not in added_nodes:
            self.added_nodes[lang] = ConfigTree([("prefix", ""), ("suffix", "")])
        if lang not in skip_node_types:
            self.skip_node_types[lang] = []

    Language.build_library(
        # Store the library in the `build` directory
        "build/my-languages.so",
        # Include one or more languages
        vendors,
    )
    self.parser = Parser()
def test_update_deployments_should_catch_exceptions(self, m):
    mist = MistApp(validate=False)
    context = models.Context('test-context')
    fn = models.Function('test-fn', 'Test', 'test-context', 'test-path.py')
    mist.update_function = MagicMock(return_value=fn)
    mist.update_context = MagicMock(return_value=context)
    mist.context_parser.parse = MagicMock(return_value=context)
    mist.function_parser.parse = MagicMock(return_value=fn)
    depls = [
        models.Deployment('simple', 'Function', ConfigTree()),
        models.Deployment('simple-ctx', 'Context', ConfigTree())
    ]
    mist.update_deployments(depls)
def parse(self, name, cfg):
    """
    :type name: str
    :param name:
    :type cfg: pyhocon.config_tree.ConfigTree
    :param cfg:
    :return:
    """
    def parse_spark_config(value, key_prefix):
        if isinstance(value, ConfigTree):
            res = dict()
            for k in value.keys():
                v = value[k]
                new_key = k if key_prefix == '' else key_prefix + '.' + k
                res.update(parse_spark_config(v, new_key))
            return res
        else:
            return {key_prefix: str(value)}

    return Context(
        name,
        cfg.get_int('max-parallel-jobs', 20),
        cfg.get_string('downtime', '120s'),
        parse_spark_config(cfg.get_config('spark-conf', ConfigTree()), ''),
        cfg.get_string('worker-mode', 'shared'),
        cfg.get_string('run-options', ''),
        cfg.get_bool('precreated', False),
        cfg.get_string('streaming-duration', '1s'))
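# Hedged example of the parse_spark_config flattening above (keys and values are
# illustrative, not from the source): a nested `spark-conf` block becomes a flat
# dict of dotted Spark keys mapped to string values.
from pyhocon import ConfigFactory

ctx_cfg = ConfigFactory.parse_string("""
spark-conf {
  spark.executor.memory = 2g
  spark.streaming.concurrentJobs = 2
}
max-parallel-jobs = 5
""")
# parse_spark_config(ctx_cfg.get_config('spark-conf'), '') would yield
# {'spark.executor.memory': '2g', 'spark.streaming.concurrentJobs': '2'}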
def init(self, conf: ConfigTree) -> None:
    self._conf = conf.with_fallback(HiveTableLastUpdatedExtractor.DEFAULT_CONFIG)

    pool_size = self._conf.get_int(HiveTableLastUpdatedExtractor.FS_WORKER_POOL_SIZE)
    LOGGER.info('Using thread pool size: {}'.format(pool_size))
    self._fs_worker_pool = ThreadPool(processes=pool_size)
    self._fs_worker_timeout = self._conf.get_int(HiveTableLastUpdatedExtractor.FS_WORKER_TIMEOUT_SEC)
    LOGGER.info('Using thread timeout: {} seconds'.format(self._fs_worker_timeout))

    self._cluster = '{}'.format(self._conf.get_string(HiveTableLastUpdatedExtractor.CLUSTER_KEY))

    self._partitioned_table_extractor = self._get_partitioned_table_sql_alchemy_extractor()
    self._non_partitioned_table_extractor = self._get_non_partitioned_table_sql_alchemy_extractor()
    self._fs = self._get_filesystem()
    self._last_updated_filecheck_threshold \
        = self._conf.get_int(HiveTableLastUpdatedExtractor.FILE_CHECK_THRESHOLD)

    self._extract_iter: Union[None, Iterator] = None
def __init__(self, cfg: ConfigTree):
    self.cfg = cfg
    print(cfg)
    self.summary_writer = SummaryWriter(log_dir=experiment_path)
    self.model_builder = ModelFactory(cfg)
    self.dataset_builder = DataLoaderFactory(cfg)
    self.train_ds = self.dataset_builder.build(split='train')
    self.test_ds = self.dataset_builder.build(split='val')
    self.ds: YoutubeDataset = self.train_ds.dataset
    self.train_criterion = nn.CrossEntropyLoss(ignore_index=self.ds.PAD_IDX)
    self.val_criterion = nn.CrossEntropyLoss(ignore_index=self.ds.PAD_IDX)
    self.model: nn.Module = self.model_builder.build(device=torch.device('cuda'),
                                                     wrapper=nn.DataParallel)
    optimizer = optim.Adam(self.model.parameters(), lr=0., betas=(0.9, 0.98), eps=1e-9)
    self.optimizer = CustomSchedule(
        self.cfg.get_int('model.emb_dim'),
        optimizer=optimizer,
    )
    self.num_epochs = cfg.get_int('num_epochs')
    logger.info(f'Use control: {self.ds.use_control}')
def _read_single_file(file_path, verbose=True):
    if not file_path or not Path(file_path).is_file():
        return ConfigTree()

    if verbose:
        print("Loading config from file %s" % file_path)

    try:
        return pyhocon.ConfigFactory.parse_file(file_path)
    except ParseSyntaxException as ex:
        msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
            file_path, ex)
        six.reraise(
            ConfigurationError,
            ConfigurationError(msg, file_path=file_path),
            sys.exc_info()[2],
        )
    except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
        msg = "Failed parsing {0} ({1.__class__.__name__}): {1}".format(file_path, ex)
        six.reraise(ConfigurationError, ConfigurationError(msg), sys.exc_info()[2])
    except Exception as ex:
        print("Failed loading %s: %s" % (file_path, ex))
        raise
def init(self, conf: ConfigTree) -> None: """ :param conf: """ self.conf = conf self.column_lineage_file_location = conf.get_string( CsvColumnLineageExtractor.COLUMN_LINEAGE_FILE_LOCATION) self._load_csv()
def _read_recursive_for_env(self, root_path_str, env, verbose=True):
    root_path = Path(root_path_str)
    if root_path.exists():
        default_config = self._read_recursive(
            root_path / Environment.default, verbose=verbose
        )
        if (root_path / env) != (root_path / Environment.default):
            env_config = self._read_recursive(
                root_path / env, verbose=verbose
            )  # None is ok, will return empty config
            config = ConfigTree.merge_configs(default_config, env_config, True)
        else:
            config = default_config
    else:
        config = ConfigTree()

    return config
def init(self, conf: ConfigTree) -> None:
    conf = conf.with_fallback(PrestoViewMetadataExtractor.DEFAULT_CONFIG)
    self._cluster = '{}'.format(conf.get_string(PrestoViewMetadataExtractor.CLUSTER_KEY))

    self.sql_stmt = PrestoViewMetadataExtractor.SQL_STATEMENT.format(
        where_clause_suffix=conf.get_string(PrestoViewMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY))

    LOGGER.info('SQL for hive metastore: {}'.format(self.sql_stmt))

    self._alchemy_extractor = SQLAlchemyExtractor()
    sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
        .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))
    self._alchemy_extractor.init(sql_alch_conf)

    self._extract_iter: Union[None, Iterator] = None
def _output_from_config(self, config: ConfigTree) -> None:
    if config is None:
        self._logger.info('No output provided in configuration, defaulting to Console')
        self._output = ConsoleOutput()
    else:
        from weather.outputs import create_output
        builder = create_output(config['name'])
        self._output = builder.from_config(ConfigTree() if 'args' not in config else config['args'])
def init(self, conf: ConfigTree) -> None: """ Establish connections and import data model class if provided :param conf: """ self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG) self.graph_url = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) self.cypher_query = conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY) self.driver = self._get_driver() self._extract_iter: Union[None, Iterator] = None model_class = conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None) if model_class: module_name, class_name = model_class.rsplit(".", 1) mod = importlib.import_module(module_name) self.model_class = getattr(mod, class_name)
def init(self, conf: ConfigTree) -> None:
    conf = conf.with_fallback(AthenaMetadataExtractor.DEFAULT_CONFIG)
    self._cluster = conf.get_string(AthenaMetadataExtractor.CATALOG_KEY)

    self.sql_stmt = AthenaMetadataExtractor.SQL_STATEMENT.format(
        where_clause_suffix=conf.get_string(AthenaMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY),
        catalog_source=self._cluster)

    LOGGER.info('SQL for Athena metadata: %s', self.sql_stmt)

    self._alchemy_extractor = SQLAlchemyExtractor()
    sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
        .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))
    self._alchemy_extractor.init(sql_alch_conf)

    self._extract_iter: Union[None, Iterator] = None
def init(self, conf: ConfigTree) -> None: """ Establish connections and import data model class if provided :param conf: """ self.conf = conf self.conn_string = conf.get_string(SQLAlchemyExtractor.CONN_STRING) self.connection = self._get_connection() self.extract_sql = conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) model_class = conf.get('model_class', None) if model_class: module_name, class_name = model_class.rsplit(".", 1) mod = importlib.import_module(module_name) self.model_class = getattr(mod, class_name) self._execute_query()
def init(self, conf: ConfigTree) -> None:
    self.success_count = 0
    self.failure_count = 0
    parsing_function = conf.get_string(PARSING_FUNCTION)
    module_name, function_name = parsing_function.rsplit(".", 1)
    mod = importlib.import_module(module_name)
    self._parsing_function = getattr(mod, function_name)
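# Minimal sketch (the class path below is an illustrative assumption) of the dotted-path
# convention shared by several init methods above: a "package.module.Name" string taken
# from config is split once from the right and resolved with importlib at init time.
import importlib

model_class = 'databuilder.models.table_metadata.TableMetadata'  # assumed example value
module_name, class_name = model_class.rsplit('.', 1)
resolved = getattr(importlib.import_module(module_name), class_name)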
def init(self, conf: ConfigTree) -> None: """ Establish connections and import data model class if provided :param conf: """ self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG) self.graph_url = conf.get_string("graph_url") self.included_keys = conf.get("included_keys", None) self.excluded_keys = conf.get("excluded_keys", None) self.included_key_regex = conf.get("included_key_regex", None) self.excluded_key_regex = conf.get("excluded_key_regex", None) # Add where clause based on configuration inputs. keys = ["schema.key", "db.key", "cluster.key", "table.key"] or_where_clauses = [] and_where_clauses = [] if self.included_keys is not None: for key in keys: or_where_clauses.append("{} IN {}".format( key, self.included_keys)) if self.excluded_keys is not None: for key in keys: and_where_clauses.append("{} NOT IN {}".format( key, self.excluded_keys)) if self.included_key_regex is not None: for key in keys: or_where_clauses.append("{} =~ '{}'".format( key, self.included_key_regex)) if self.excluded_key_regex is not None: for key in keys: and_where_clauses.append("NOT {} =~ '{}'".format( key, self.excluded_key_regex)) where_clause = combine_where_clauses(and_clauses=and_where_clauses, or_clauses=or_where_clauses) self.cypher_query = AmundsenNeo4jMetadataExtractor.CYPHER_QUERY.format( where_clause=where_clause) self.driver = self._get_driver() self._extract_iter = None
def merge(cls, configs):
    for c in configs:
        assert isinstance(c, Config)
    ctree = configs[0]._config_tree
    for c in configs[1:]:
        ctree = ConfigTree.merge_configs(ctree, c._config_tree)
    return cls(ctree)
def __init__(self, config_tree=None):
    """Create a Config.

    Args:
        config_tree (ConfigTree)
    """
    if config_tree is None:
        config_tree = ConfigTree()
    self._config_tree = config_tree
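# Assumed usage sketch for the Config wrapper and merge above (keys and values are made up,
# and merge is assumed to be exposed as a classmethod): merge folds the wrapped ConfigTrees
# left to right, so later configs override earlier ones, matching ConfigTree.merge_configs.
from pyhocon import ConfigFactory

base = Config(ConfigFactory.parse_string("a = 1\nb = 2"))
override = Config(ConfigFactory.parse_string("b = 3"))
merged = Config.merge([base, override])
assert merged._config_tree['a'] == 1
assert merged._config_tree['b'] == 3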
def _load_custom_config(run_config):
    """Load a custom HOCON configuration input file for Cromwell.
    """
    from pyhocon import ConfigFactory, HOCONConverter, ConfigTree
    conf = ConfigFactory.parse_file(run_config)
    out = {}
    if "database" in conf:
        out["database"] = HOCONConverter.to_hocon(ConfigTree({"database": conf.get_config("database")}))
    return out
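# Hypothetical input for _load_custom_config above (profile and url values are illustrative):
# only the `database` section of the parsed HOCON file is rendered back to HOCON text.
from pyhocon import ConfigFactory, HOCONConverter, ConfigTree

conf = ConfigFactory.parse_string("""
database {
  profile = "slick.jdbc.HsqldbProfile$"
  db.url = "jdbc:hsqldb:mem:cromwell"
}
backend { default = Local }
""")
rendered = HOCONConverter.to_hocon(ConfigTree({"database": conf.get_config("database")}))
# `rendered` is a HOCON string containing just the database { ... } block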
def test_with_user_name(self):
    d = Deployment('test', 'Function', ConfigTree({
        'context': 'foo',
        'path': 'test-name.jar'
    }), '0.0.1')
    d.with_user('test_name')
    self.assertEqual(d.name, 'test_name_test')
    self.assertEqual(d.data['path'], 'test_name_test-name.jar')
    self.assertEqual(d.data['context'], 'test_name_foo')
def init(self, conf: ConfigTree) -> None:
    self._progress_report_frequency = \
        conf.get_int('{}.{}'.format(self.get_scope(), DefaultTask.PROGRESS_REPORT_FREQUENCY), 500)
    self.extractor.init(Scoped.get_scoped_conf(conf, self.extractor.get_scope()))
    self.transformer.init(Scoped.get_scoped_conf(conf, self.transformer.get_scope()))
    self.loader.init(Scoped.get_scoped_conf(conf, self.loader.get_scope()))
def load_default_config(path: Union[Path, str] = None):
    """Load custom configuration from specified file. Modifies global constants!"""
    # pylint: disable=global-statement
    global DEFAULT_CONFIG
    if path:
        DEFAULT_CONFIG = ConfigTree.merge_configs(
            DEFAULT_CONFIG, ConfigFactory.parse_file(str(path)))
    else:
        DEFAULT_CONFIG = default_config(RESOURCES_PATH, "graph.conf")
def get_config_by_env(configs, current_env):
    try:
        common_configs = configs.get('common', ConfigTree())
        env_configs = configs.get(current_env)
        merge_configs = Configger.merge_configs([common_configs, env_configs])
        return merge_configs
    except Exception:
        logger.error(traceback.format_exc())
        logger.error("Fetch config by {} ENV error.".format(current_env))
def from_cfg(cls, cfg: ConfigTree, split='train'):
    return cls(cfg.get_string('dataset.split_csv_dir'),
               cfg.get_config('dataset.streams'),
               split=split,
               duration=cfg.get_float('dataset.duration'),
               duplication=cfg.get_int('dataset.duplication'),
               fps=cfg.get_float('dataset.fps'),
               events_per_sec=cfg.get_int('dataset.events_per_sec'),
               random_shift_rate=cfg.get_float('dataset.random_shift_rate', 0.2),
               pose_layout=cfg.get_string('dataset.pose_layout'))
def test_self_merge_ref_substitutions_object(self):
    config1 = ConfigFactory.parse_string(
        """
        a : { }
        b : 1
        c : ${a} { d : [ ${b} ] }
        """,
        resolve=False
    )
    config2 = ConfigFactory.parse_string(
        """
        e : ${a} { }
        """,
        resolve=False
    )
    merged = ConfigTree.merge_configs(config1, config2)
    resolved = ConfigParser.resolve_substitutions(merged)
    assert resolved.get("c.d") == [1]
def test_self_merge_ref_substitutions_object3(self):
    config1 = ConfigFactory.parse_string(
        """
        b1 : { v1: 1 }
        b = [${b1}]
        """,
        resolve=False
    )
    config2 = ConfigFactory.parse_string(
        """
        b1 : { v1: 2, v2: 3 }
        """,
        resolve=False
    )
    merged = ConfigTree.merge_configs(config1, config2)
    resolved = ConfigParser.resolve_substitutions(merged)
    assert resolved.get("b1") == {"v1": 2, "v2": 3}
    b = resolved.get("b")
    assert len(b) == 1
    assert b[0] == {"v1": 2, "v2": 3}
def test_self_merge_ref_substitutions_object2(self):
    config1 = ConfigFactory.parse_string(
        """
        x : { v1: 1 }
        b1 : { v2: 2 }
        b = [${b1}]
        """,
        resolve=False
    )
    config2 = ConfigFactory.parse_string(
        """
        b2 : ${x} { v2: 3 }
        b += [${b2}]
        """,
        resolve=False
    )
    merged = ConfigTree.merge_configs(config1, config2)
    resolved = ConfigParser.resolve_substitutions(merged)
    b = resolved.get("b")
    assert len(b) == 2
    assert b[0] == {'v2': 2}
    assert b[1] == {'v1': 1, 'v2': 3}