def __call__(self):
    # read the configuration out of S3 (assumed already downloaded to local_cfg_file_location)
    with open(local_cfg_file_location) as data_file:
        config = json.load(data_file)
    # set up the environment variables
    for var in config['variables']:
        Variable.set(var['name'], var['value'])
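# A minimal sketch of the JSON shape the callable above assumes -- all it
# relies on is a top-level "variables" list of name/value pairs. The values
# here are hypothetical.
example_config = {
    "variables": [
        {"name": "ENVIRONMENT", "value": "staging"},
        {"name": "S3_BUCKET", "value": "my-data-bucket"},
    ]
}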
def test_var_with_encryption_rotate_fernet_key(self, mock_get):
    """
    Tests rotating encrypted variables.
    """
    key1 = Fernet.generate_key()
    key2 = Fernet.generate_key()

    mock_get.return_value = key1.decode()
    Variable.set('key', 'value')
    session = settings.Session()
    test_var = session.query(Variable).filter(Variable.key == 'key').one()
    self.assertTrue(test_var.is_encrypted)
    self.assertEqual(test_var.val, 'value')
    self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), b'value')

    # Test decrypt of old value with new key
    mock_get.return_value = ','.join([key2.decode(), key1.decode()])
    crypto._fernet = None
    self.assertEqual(test_var.val, 'value')

    # Test decrypt of new value with new key
    test_var.rotate_fernet_key()
    self.assertTrue(test_var.is_encrypted)
    self.assertEqual(test_var.val, 'value')
    self.assertEqual(Fernet(key2).decrypt(test_var._val.encode()), b'value')
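# The test above relies on Airflow accepting a comma-separated list of fernet
# keys (newest first): old values stay decryptable while rotate_fernet_key()
# re-encrypts with the new key. A minimal sketch of rotating every stored
# variable under that assumption (recent Airflow versions expose this as the
# `airflow rotate-fernet-key` CLI command):
from airflow import settings
from airflow.models import Variable

def rotate_all_variables():
    session = settings.Session()
    for var in session.query(Variable).all():
        var.rotate_fernet_key()  # re-encrypt with the first (newest) key
    session.commit()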
def variables(args):
    if args.get:
        try:
            var = Variable.get(args.get,
                               deserialize_json=args.json,
                               default_var=args.default)
            print(var)
        except ValueError as e:
            print(e)
    if args.delete:
        session = settings.Session()
        session.query(Variable).filter_by(key=args.delete).delete()
        session.commit()
        session.close()
    if args.set:
        Variable.set(args.set[0], args.set[1])
    # Work around 'import' as a reserved keyword
    imp = getattr(args, 'import')
    if imp:
        if os.path.exists(imp):
            import_helper(imp)
        else:
            print("Missing variables file.")
    if args.export:
        export_helper(args.export)
    if not (args.set or args.get or imp or args.export or args.delete):
        # list all variables
        session = settings.Session()
        vars = session.query(Variable)
        msg = "\n".join(var.key for var in vars)
        print(msg)
def test_variable_with_encryption(self):
    """
    Test variables with encryption
    """
    Variable.set('key', 'value')
    session = settings.Session()
    test_var = session.query(Variable).filter(Variable.key == 'key').one()
    self.assertTrue(test_var.is_encrypted)
    self.assertEqual(test_var.val, 'value')
def register_issues(ds, **kwargs):
    """Registers or updates all issues from the Kernel.

    Issues whose journal cannot be found are flagged as orphans and stored
    in a persistent variable for future retries.
    """
    tasks = kwargs["ti"].xcom_pull(key="tasks", task_ids="read_changes_task")
    known_issues = kwargs["ti"].xcom_pull(
        key="known_issues", task_ids="register_journals_task")
    journals_aop = kwargs["ti"].xcom_pull(
        key="journals_aop", task_ids="register_journals_task")

    def _journal_id(issue_id):
        """Returns the identifier of the journal that contains `issue_id`."""
        for journal_id, issues in known_issues.items():
            for issue in issues:
                if issue_id == issue["id"]:
                    return journal_id

    def _issue_order(issue_id):
        """The position relative to the journal's other issues.

        May raise `ValueError` if `issue_id` is not listed among the
        issues of journal `journal_id`.
        """
        issues = known_issues.get(_journal_id(issue_id), [])
        for issue in issues:
            if issue_id == issue["id"]:
                return issue["order"]

    def _journal_aop_id(aop_id):
        """Returns the journal identifier from the AOP list."""
        return journals_aop[aop_id]

    issues_to_get = itertools.chain(
        Variable.get("orphan_issues", default_var=[], deserialize_json=True),
        (get_id(task["id"]) for task in filter_changes(tasks, "bundles", "get")),
    )

    # Register the AOPs; the `order` attribute is not required for AOPs
    orphans, known_documents = try_register_issues(
        journals_aop.keys(), _journal_aop_id, None, fetch_bundles, IssueFactory, True)

    # Register the regular issues
    orphans, known_documents = try_register_issues(
        issues_to_get, _journal_id, _issue_order, fetch_bundles, IssueFactory)

    kwargs["ti"].xcom_push(key="i_documents", value=known_documents)
    Variable.set("orphan_issues", orphans, serialize_json=True)

    return tasks
def get_teams_schema(ds, **kwargs):
    """
    This task lists the last config for every team. Then a Neo4j query is
    done to count the nodes of each label.
    """
    with kwargs["params"]["app"].app_context():
        from depc.controllers.configs import ConfigController

        # Get all configs ordered by -date
        configs = ConfigController._list(order_by="updated_at", reverse=True)

        # Get the last config by team
        teams = {}
        for config in configs:
            team = config.team

            # For each team
            if team.kafka_topic not in teams.keys():
                logger.info("[{0}] Configuration : {1}".format(team.name, config.data))
                data = {
                    "id": str(team.id),
                    "name": team.name,
                    "topic": team.kafka_topic,
                    "schema": config.data,
                    "labels": {},
                }

                # Count number of nodes per label
                logger.info(
                    "[{0}] Counting nodes for {1} labels...".format(
                        team.name, len(config.data.keys())
                    )
                )
                for label in config.data.keys():
                    neo_key = "{}_{}".format(team.kafka_topic, label)
                    records = get_records(
                        "MATCH (n:{label}) RETURN count(n) AS Count".format(
                            label=neo_key
                        )
                    )
                    count = list(records)[0].get("Count")
                    logger.info(
                        "[{0}] {1} nodes for label {2}...".format(
                            team.name, count, label
                        )
                    )
                    data["labels"][label] = count

                teams[team.kafka_topic] = data

    # Save the config into an Airflow variable
    Variable.set("config", list(teams.values()), serialize_json=True)
def test_variable_with_encryption(self, mock_get):
    """
    Test variables with encryption
    """
    mock_get.return_value = Fernet.generate_key().decode()
    Variable.set('key', 'value')
    session = settings.Session()
    test_var = session.query(Variable).filter(Variable.key == 'key').one()
    self.assertTrue(test_var.is_encrypted)
    self.assertEqual(test_var.val, 'value')
def test_variable_no_encryption(self, mock_get):
    """
    Test variables without encryption
    """
    mock_get.return_value = ''
    Variable.set('key', 'value')
    session = settings.Session()
    test_var = session.query(Variable).filter(Variable.key == 'key').one()
    self.assertFalse(test_var.is_encrypted)
    self.assertEqual(test_var.val, 'value')
def set_vars_from_env():
    load_dotenv()
    Variable.set('FITBIT_ACCESS', os.environ.get('FITBIT_ACCESS'))
    Variable.set('FITBIT_APP_TOKEN', os.environ.get('FITBIT_APP_TOKEN'))
    Variable.set('FITBIT_REFRESH', os.environ.get('FITBIT_REFRESH'))
    Variable.set('LOCAL_STAGING', os.environ.get('LOCAL_STAGING'))
    Variable.set('WEATHERBIT_KEY', os.environ.get('WEATHERBIT_KEY'))
    print('Airflow variables set')
def set_params(*argv, **kwargs):
    # In the real version these will be set from kwargs['dag_run'].conf dict
    run_id = kwargs['run_id']
    ingest_id = run_id
    dag_params = Variable.get("dag_params", deserialize_json=True)
    new_dag_params = kwargs['params'].copy()
    new_dag_params['run_id'] = run_id
    new_dag_params['ingest_id'] = ingest_id
    Variable.set("dag_params", new_dag_params, serialize_json=True)
    return 'Whatever you return gets printed in the logs'
def test_should_respond_200(self):
    expected_value = '{"foo": 1}'
    Variable.set("TEST_VARIABLE_KEY", expected_value)
    response = self.client.get(
        "/api/v1/variables/TEST_VARIABLE_KEY",
        environ_overrides={'REMOTE_USER': "******"},
    )
    assert response.status_code == 200
    assert response.json == {
        "key": "TEST_VARIABLE_KEY",
        "value": expected_value,
    }
def execute(self, context):
    oracle = OracleHelper(self.oracle_conn_id)
    self.log.info(
        f"Executing SQL: {self.sql_count_id}\nParameters: {self.dict_bind}")
    count_id = oracle.get_rows_with_bind(
        sql=self.sql_count_id, bind=self.dict_bind)[0][0]
    Variable.set(key=f'{self.current_dag_name}_total_row_id', value=count_id)
    self.log.info(f"{count_id} rows are not in HDFS.")
def execute(self, context):
    client = boto3.client('s3')
    # aws_access_key_id=self.ACCESS_KEY,
    # aws_secret_access_key=self.SECRET_KEY,
    # aws_session_token=self.SESSION_TOKEN)
    multipart = client.create_multipart_upload(Bucket=self.bucket, Key=self.key)
    # Variable.get returns a plain string unless deserialize_json=True, so
    # round-trip the dict through JSON explicitly
    master_variable_dict = Variable.get(self.master_variable, deserialize_json=True)
    master_variable_dict['UploadId'] = multipart['UploadId']
    Variable.set(self.master_variable, master_variable_dict, serialize_json=True)
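# A minimal sketch of the follow-up step the persisted UploadId enables: a
# later task recovers it from the variable, uploads parts, and completes the
# multipart upload. Names such as `part_bytes` are hypothetical.
import boto3
from airflow.models import Variable

def complete_upload(master_variable, bucket, key, part_bytes):
    upload_id = Variable.get(master_variable, deserialize_json=True)['UploadId']
    client = boto3.client('s3')
    part = client.upload_part(
        Bucket=bucket, Key=key, PartNumber=1,
        UploadId=upload_id, Body=part_bytes,
    )
    client.complete_multipart_upload(
        Bucket=bucket, Key=key, UploadId=upload_id,
        MultipartUpload={'Parts': [{'ETag': part['ETag'], 'PartNumber': 1}]},
    )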
def test_should_raises_401_unauthenticated(self): Variable.set("delete_var1", 1) # make sure variable is added response = self.client.delete("/api/v1/variables/delete_var1") assert_401(response) # make sure variable is not deleted response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "******"}) assert response.status_code == 200
def post_variables() -> Response:
    """
    Create a variable
    """
    try:
        var = variable_schema.load(request.json)
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))
    Variable.set(var.data["key"], var.data["val"])
    return variable_schema.dump(var)
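# A minimal client-side sketch of calling the endpoint above. The base URL,
# credentials, and the "value" field name follow Airflow's stable REST API
# and are assumptions, not taken from the handler itself.
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/variables",
    json={"key": "my_var", "value": "my_value"},
    auth=("admin", "admin"),
)
resp.raise_for_status()
print(resp.json())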
def robot_auth():
    conn = base_hook('robot_auth')
    auth = conn['host']
    headers = {'content-type': 'application/json'}
    data = {'username': conn['user'], 'password': conn['password']}
    r = requests.post(auth, data=json.dumps(data), headers=headers, timeout=10)
    token = r.json()['access_token']
    Variable.set('robot_token', token)
def showQueueName(dag_id, task_id):
    try:
        queue = Variable.get('.'.join(['Queue', dag_id, task_id]),
                             default_var=None)
    except Exception:
        queue = None
    if queue is None:
        queue = AirflowConf.get('celery', 'default_queue')
        Variable.set('.'.join(['Queue', dag_id, task_id]), queue)
    return queue
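# A minimal sketch of wiring the per-task queue lookup above into a DAG,
# assuming showQueueName is importable; the DAG and task names are
# hypothetical.
from datetime import datetime
from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG("example_dag", start_date=datetime(2021, 1, 1), schedule_interval=None):
    task = BashOperator(
        task_id="example_task",
        bash_command="echo hello",
        # route the task to its stored queue, falling back to the Celery default
        queue=showQueueName("example_dag", "example_task"),
    )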
def dag_config(airflow_init_db):
    variable_name = 'pipe_test'
    value = dict(
        project_id='test_project',
        pipeline_dataset='dataset',
        pipeline_bucket='bucket',
        foo='bar',
    )
    Variable.set(variable_name, value, serialize_json=True)
    return variable_name
def start(bucket, keys, file_paths):
    timestamp_prefix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    Variable.set("timestamp", timestamp_prefix)
    s3 = boto3.client('s3')
    input_key = keys[0]
    input_file = file_paths[0]
    preproc_key = keys[1]
    preproc_file = file_paths[1]
    s3.upload_file(Filename=input_file, Bucket=bucket, Key=input_key)
    s3.upload_file(Filename=preproc_file, Bucket=bucket, Key=preproc_key)
def make_deploy_decision(**kwargs):
    variable_name = kwargs['params']['deploy_name'] + '-best-metric'
    task_instance = kwargs['task_instance']
    experiment_id = task_instance.xcom_pull(task_ids='train')
    best_metric = task_instance.xcom_pull(task_ids='wait')
    last_best = Variable.get(variable_name, default_var=None)
    # Variables are stored as strings, hence float() on the stored value
    if last_best is None or best_metric < float(last_best):
        Variable.set(variable_name, best_metric)
        return 'deploy'
    else:
        return 'failure'
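# Since make_deploy_decision returns a task_id ('deploy' or 'failure'), it
# reads like a branching callable. A minimal wiring sketch under that
# assumption: 'deploy' and 'failure' must exist as task_ids in the same DAG,
# and the deploy_name param is hypothetical.
from airflow.operators.python import BranchPythonOperator

decide = BranchPythonOperator(
    task_id="make_deploy_decision",
    python_callable=make_deploy_decision,
    params={"deploy_name": "my-model"},
)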
def espn_data_download():
    # w = 4
    lstart = Variable.get("NFL_START_DATE")
    w = datetime.today().isocalendar()[1] - datetime.strptime(
        '2020-9-11', '%Y-%m-%d').isocalendar()[1]
    Variable.set("week", str(w))
    # lid = 866268
    lid = Variable.get("ESPN_LEAGUE")
    sc_data = get_scoreboard(w, lid)
    with open('data/sc_data_{}.json'.format(w), 'w') as outfile:
        json.dump(sc_data, outfile)
def saving_output_filenames(self):
    if self.output_files is not None:
        output_files_var = []
        for k, v in self.output_files.items():
            # making sure directory exists
            pathlib.Path(osp.dirname(v.path)).mkdir(parents=True, exist_ok=True)
            output_files_var.append((k, v.path))
        # Saving the output files list as an Airflow variable
        Variable.set(f"{self.get_variables_prefix()}_output_files",
                     json.dumps(output_files_var))
def robot_auth():
    conn = BaseHook.get_connection('robot_auth')
    auth = conn.host
    headers = {'content-type': 'application/json'}
    data = {'username': conn.login, 'password': conn.password}
    r = requests.post(auth, data=json.dumps(data), headers=headers, timeout=10)
    token = r.json()['access_token']
    Variable.set('robot_token', token)
def execute(self, context): self.log.info("Creating EMR cluster cluster={0} at region={1}".format( self.cluster_name, self.region_name)) self.log.info("EMR cluster number_of_nodes={0}".format( self.num_core_nodes)) task_instance = context['task_instance'] cluster_id = self.create_cluster() Variable.set("cluster_id", cluster_id) task_instance.xcom_push('cluster_id', cluster_id) self.log.info("The newly create_cluster_id = {0}".format(cluster_id)) return cluster_id
def test_should_raises_401_unauthenticated(self): Variable.set("var1", "foo") response = self.client.patch( "/api/v1/variables/var1", json={ "key": "var1", "value": "updated", }, ) assert_401(response)
def test_variable_no_encryption(self):
    """
    Test variables without encryption
    """
    Variable.set('key', 'value')
    session = settings.Session()
    test_var = session.query(Variable).filter(Variable.key == 'key').one()
    assert not test_var.is_encrypted
    assert test_var.val == 'value'
    # We always call mask_secret for variables, and let the SecretsMasker decide
    # based on the name if it should mask anything. That logic is tested in
    # test_secrets_masker.py
    self.mask_secret.assert_called_once_with('value', 'key')
def _processing_user(ti):
    users_txt = ti.xcom_pull(task_ids=["fetch_user"])[0]
    users = json.loads(users_txt)
    if not len(users) or 'results' not in users:
        raise ValueError("User is empty")
    user = users['results'][0]
    user_map = {
        'firstname': user['name']['first'],
        'lastname': user['name']['last'],
    }
    processed_user = json.dumps(user_map)
    Variable.set("user", processed_user)
def test_write(self):
    """
    Test records can be written and overwritten
    """
    Variable.set(key="test_key", value="test_val")
    session = settings.Session()
    result = session.query(RTIF).all()
    assert [] == result

    with DAG("test_write", start_date=START_DATE):
        task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

    rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
    rtif.write()
    result = (
        session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
        .filter(
            RTIF.dag_id == rtif.dag_id,
            RTIF.task_id == rtif.task_id,
            RTIF.execution_date == rtif.execution_date,
        )
        .first()
    )
    assert ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}) == result

    # Test that overwrite saves new values to the DB
    Variable.delete("test_key")
    Variable.set(key="test_key", value="test_val_updated")
    with DAG("test_write", start_date=START_DATE):
        updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

    rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE))
    rtif_updated.write()
    result_updated = (
        session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
        .filter(
            RTIF.dag_id == rtif_updated.dag_id,
            RTIF.task_id == rtif_updated.task_id,
            RTIF.execution_date == rtif_updated.execution_date,
        )
        .first()
    )
    assert (
        'test_write',
        'test',
        {'bash_command': 'echo test_val_updated', 'env': None},
    ) == result_updated
def check_jail_profiles(output_path, **kwargs):
    """
    In production, this task will check the list received from the scrape_jail
    task to see who already has a profile in our database. For those who do
    not, it will produce a list, and it will then divide that list among the
    appropriate amount of workers via csv files.

    To keep our dag simple for the final project, I am simply dividing tasks
    up among the workers and not checking their status in the database
    """
    # load filepaths from required task
    reqs = requires("scrape_jail", **kwargs)
    logging.info("Requirements: %s", reqs)

    # get info for people in jail (stored in people.csv in the dir created by 'scrape_jail')
    with open(reqs["people"], "r", newline="") as fout:
        data = list(csv.reader(fout))
    logging.info("opened people.csv")

    # user decides scrapes_per_worker depending on personal preference and number of scrapers
    scrapes_per_worker = int(Variable.get("scrapes_per_worker", default_var=3))
    logging.info("scrapes_per_worker = %s", scrapes_per_worker)

    # how many people need their profiles scraped
    num_people_to_scrape = len(data)
    logging.info("num_people_to_scrape = %s", num_people_to_scrape)

    # this determines how many tasks we create to do the scraping (which allows
    # it to be done in parallel when deployed)
    num_tasks = math.ceil(num_people_to_scrape / scrapes_per_worker)
    logging.info("num_tasks = %s", num_tasks)

    # split the big list of people into "to do lists" for each of the workers.
    # If a worker fails, it can be re-run and pull in exactly the same people
    for x in range(num_tasks):
        chunk = data[x * scrapes_per_worker + 1:(x + 1) * scrapes_per_worker + 1]
        out_path = os.path.join(output_path, f"todo_{x}.csv")
        with open(out_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerows(chunk)
        logging.info(f"wrote todo_{x}")

    # set variable in Airflow (stored in meta-db) to use in constructing dynamic DAG later
    Variable.set("num_odyssey_scraping_tasks", num_tasks)

    # set variable that controls concurrency (concurrency proportional to the
    # percentage of scrapes, up to 1000)
    max_concurrency = os.cpu_count() - 1
    concurrency = (math.ceil(max_concurrency * num_tasks / 1000)
                   if num_tasks / 1000 < 1 else max_concurrency)
    Variable.set("concurrency", concurrency)

    return "Complete"
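# A minimal sketch of the dynamic DAG construction the comments above allude
# to: read the fan-out width persisted in num_odyssey_scraping_tasks and
# create one task per todo file. The scrape_profiles callable and DAG name
# are hypothetical.
from datetime import datetime
from airflow import DAG
from airflow.models import Variable
from airflow.operators.python import PythonOperator

# default to 0 so the DAG file still parses before the variable has been set
num_tasks = int(Variable.get("num_odyssey_scraping_tasks", default_var=0))

with DAG("odyssey_scraping", start_date=datetime(2021, 1, 1), schedule_interval=None):
    for x in range(num_tasks):
        PythonOperator(
            task_id=f"scrape_profiles_{x}",
            python_callable=scrape_profiles,  # hypothetical worker callable
            op_kwargs={"todo_file": f"todo_{x}.csv"},
        )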
def test_should_delete_variable(self):
    Variable.set("delete_var1", 1)
    # make sure variable is added
    response = self.client.get("/api/v1/variables/delete_var1")
    assert response.status_code == 200

    response = self.client.delete("/api/v1/variables/delete_var1")
    assert response.status_code == 204

    # make sure variable is deleted
    response = self.client.get("/api/v1/variables/delete_var1")
    assert response.status_code == 404
def test_variable_set_with_env_variable(self):
    Variable.set("key", "db-value")
    with self.assertLogs(variable.log) as log_context:
        with mock.patch.dict('os.environ', AIRFLOW_VAR_KEY="env-value"):
            Variable.set("key", "new-db-value")
            assert "env-value" == Variable.get("key")
        assert "new-db-value" == Variable.get("key")

    assert log_context.records[0].message == (
        'You have the environment variable AIRFLOW_VAR_KEY defined, which takes precedence over '
        'reading from the database. The value will be saved, but to read it you have to delete '
        'the environment variable.'
    )
def count_utterances_file_chunks(**kwargs):
    get_variables()
    utterances_names = json.loads(Variable.get("utteranceschunkslist"))
    all_blobs = list_blobs_in_a_path(bucket_name, source_chunk_path)
    list_of_blobs = []
    for blob in all_blobs:
        if blob.name.endswith(".csv"):
            list_of_blobs.append(str(blob.name))
    print("***The utterances file chunks***", list_of_blobs)
    utterances_names["utteranceschunkslist"] = list_of_blobs
    utterances_names = mydict(utterances_names)
    Variable.set("utteranceschunkslist", utterances_names)
def update_schemas(**kwargs):
    schemas = get_all_schemas()

    # we update all schemas that we found:
    for key, value in schemas.items():
        Variable.set(key=key, value=value, serialize_json=True)

    # now we clean the variables that do not exist anymore:
    with create_session() as session:
        current_vars = set(var.key for var in session.query(Variable))
        apps_to_delete = current_vars - schemas.keys()
        print("About to delete old apps: {}".format(apps_to_delete))
        for _var in apps_to_delete:
            Variable.delete(_var, session)
def variables(args):
    if args.get:
        try:
            var = Variable.get(args.get,
                               deserialize_json=args.json,
                               default_var=args.default)
            print(var)
        except ValueError as e:
            print(e)
    if args.set:
        Variable.set(args.set[0], args.set[1])
    if not args.set and not args.get:
        # list all variables
        session = settings.Session()
        vars = session.query(Variable)
        msg = "\n".join(var.key for var in vars)
        print(msg)
def MonthlyGenerateTestArgs(**kwargs):
    """Loads the configuration that will be used for this iteration."""
    conf = kwargs['dag_run'].conf
    if conf is None:
        conf = dict()

    # If version is overridden then we should use it, otherwise we use its
    # default or monthly value.
    version = conf.get('VERSION') or istio_common_dag.GetVariableOrDefault('monthly-version', None)
    if not version or version == 'INVALID':
        raise ValueError('version needs to be provided')
    Variable.set('monthly-version', 'INVALID')

    # GCS_MONTHLY_STAGE_PATH is of the form 'prerelease/{version}'
    gcs_path = 'prerelease/%s' % (version)

    branch = conf.get('BRANCH') or istio_common_dag.GetVariableOrDefault('monthly-branch', None)
    if not branch or branch == 'INVALID':
        raise ValueError('branch needs to be provided')
    Variable.set('monthly-branch', 'INVALID')

    mfest_commit = conf.get('MFEST_COMMIT') or branch

    default_conf = environment_config.GetDefaultAirflowConfig(
        branch=branch,
        gcs_path=gcs_path,
        mfest_commit=mfest_commit,
        pipeline_type='monthly',
        verify_consistency='true',
        version=version)

    config_settings = dict()
    for name in default_conf.keys():
        config_settings[name] = conf.get(name) or default_conf[name]

    # These are the extra params that are passed to the dags for the monthly release
    monthly_conf = dict()
    monthly_conf['DOCKER_HUB'] = 'istio'
    monthly_conf['GCR_RELEASE_DEST'] = 'istio-io'
    monthly_conf['GCS_GITHUB_PATH'] = 'istio-secrets/github.txt.enc'
    monthly_conf['RELEASE_PROJECT_ID'] = 'istio-io'
    # GCS_MONTHLY_RELEASE_PATH is of the form 'istio-release/releases/{version}'
    monthly_conf['GCS_MONTHLY_RELEASE_PATH'] = 'istio-release/releases/%s' % (version)
    for name in monthly_conf.keys():
        config_settings[name] = conf.get(name) or monthly_conf[name]

    testMonthlyConfigSettings(config_settings)
    return config_settings
def import_helper(filepath):
    with open(filepath, 'r') as varfile:
        var = varfile.read()

    try:
        d = json.loads(var)
    except Exception:
        print("Invalid variables file.")
    else:
        try:
            n = 0
            for k, v in d.items():
                if isinstance(v, dict):
                    Variable.set(k, v, serialize_json=True)
                else:
                    Variable.set(k, v)
                n += 1
        except Exception:
            pass
        finally:
            print("{} of {} variables successfully updated.".format(n, len(d)))
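# A minimal example of a variables file import_helper accepts (keys and
# values here are hypothetical): nested objects are stored with
# serialize_json=True, everything else as plain strings.
import json

with open("variables.json", "w") as f:
    json.dump({
        "environment": "staging",
        "retries": "3",
        "pipeline_config": {"bucket": "my-bucket", "dataset": "my_dataset"},
    }, f)

import_helper("variables.json")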
def ReportMonthlySuccessful(task_instance, **kwargs):
    del kwargs
    version = istio_common_dag.GetSettingPython(task_instance, 'VERSION')
    try:
        match = re.match(r'([0-9])\.([0-9])\.([0-9]).*', version)
        major, minor, patch = match.group(1), match.group(2), match.group(3)
        Variable.set('major_version', major)
        Variable.set('released_version_minor', minor)
        Variable.set('released_version_patch', patch)
    except (IndexError, AttributeError):
        logging.error('Could not extract released version information.\n'
                      'Please set the airflow version Variables manually. '
                      'After you are done hit Mark Success.')
def ReportDailySuccessful(task_instance, **kwargs):
    date = kwargs['execution_date']
    latest_run = float(Variable.get('latest_daily_timestamp'))

    timestamp = time.mktime(date.timetuple())
    logging.info('Current run\'s timestamp: %s \n'
                 'latest_daily\'s timestamp: %s', timestamp, latest_run)
    if timestamp >= latest_run:
        Variable.set('latest_daily_timestamp', timestamp)
        run_sha = task_instance.xcom_pull(task_ids='get_git_commit')
        latest_version = GetSettingPython(task_instance, 'VERSION')
        logging.info('setting latest green daily to: %s', run_sha)
        Variable.set('latest_sha', run_sha)
        Variable.set('latest_daily', latest_version)
        logging.info('latest_sha set to %s', run_sha)
def set_group(*args, **kwargs):
    if datetime.now().hour > 18 or datetime.now().hour < 8:
        Variable.set('group', 'night_shift')
    else:
        Variable.set('group', 'day_shift')
def test_variable_set_get_round_trip_json(self):
    value = {"a": 17, "b": 47}
    Variable.set("tested_var_set_id", value, serialize_json=True)
    assert value == Variable.get("tested_var_set_id", deserialize_json=True)
def test_variable_set_get_round_trip(self):
    Variable.set("tested_var_set_id", "Monday morning breakfast")
    assert "Monday morning breakfast" == Variable.get("tested_var_set_id")