Example #1
    def test_context_removed_after_exit(self):
        example_context = {"Hello": "World"}

        with set_current_context(example_context):
            pass
        with pytest.raises(AirflowException):
            get_current_context()
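This test snippet exercises Airflow's internal context stack directly; in DAG code the same get_current_context() is reached through the TaskFlow API. A minimal sketch of that everyday pattern (DAG id and schedule are hypothetical; assumes Airflow 2.4+):

import pendulum
from airflow.decorators import dag, task
from airflow.operators.python import get_current_context

@dag(schedule=None, start_date=pendulum.datetime(2021, 1, 1), catchup=False)
def context_demo():
    @task
    def show_ds():
        # get_current_context() only works while a task is executing;
        # outside of one it raises AirflowException (as the test above checks).
        context = get_current_context()
        print(context["ds"])

    show_ds()

context_demo()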
Example #2
def do_backfill(sql, intervals):
    import multiprocessing as mp
    from ctypes import c_bool
    params = get_current_context()['params']
    schema = params['schema']
    table = params['table']

    args = [(sql, i) for i in intervals[:-1]]

    serving_sql = (
        sql.replace(f'{schema}.{table}_HISTORY', f'{schema}.{table}')
        .replace('ASOF_TS', 'LAST_UPDATE_TS')
        .replace('INGEST_TS,', '')
        .replace('CURRENT_TIMESTAMP,', '')
    )
    args.append((serving_sql, intervals[-1]))

    failed = mp.Value(c_bool, False)

    with mp.Pool(10, initializer=init_globals, initargs=(failed, )) as p:
        try:
            p.starmap(handle_runs, args)
        except Exception as e:
            print(str(e))
            cleanup(schema, table)
            raise AirflowException(
                "An error occurred during the backfill process. "
                "The Feature Set tables have been cleared.") from e
Example #3
def task_a():
    ctx = get_current_context()
    logger.info(
        "Context **************************************************")
    logger.info(pformat(ctx))
    logger.info("dag_run.conf ******************************")
    logger.info(pformat(getattr(ctx.get("dag_run"), "conf", None)))
    return {"a_return_value": "test"}
Example #4
def my_task():
    context = get_current_context()
    ds = context["ds"]
    OUTPUT_DIR = WORKING_DIR + "/data/fasta/last100k-by-collection-date"
    default_args["params"]["output-dir"] = OUTPUT_DIR
    default_args["params"][
        "meta-output"] = OUTPUT_DIR + '/master-no-sequences.json'
    default_args["params"]["sequence-output"] = OUTPUT_DIR + '/sequences'
Example #5
def print_logs():
    context = get_current_context()
    for db in DBS:
        msg = context["ti"].xcom_pull(
            task_ids="query", dag_id=f"dag_id_{db}",
            key=f"{db}_rows_count", include_prior_dates=True,
        )
        print(f"the pulled message is: {msg}")
Example #6
    def getTodayDate():
        """
        Get the current context of the Airflow task and use it to read the
        execution date ("ds").
        """
        context = {"test_date": get_current_context()["ds"]}
        print(context)
        return context
Example #7
def set_up():
    ti = get_current_context()['ti']
    dms_client = boto3.client('dms')

    rds_instance_endpoint = _create_rds_instance()
    _create_rds_table(rds_instance_endpoint)
    instance_arn = _create_dms_replication_instance(ti, dms_client)
    _create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
    _await_setup_assets(dms_client, instance_arn)
Example #8
def _check_and_transform_video_ids(task_output):
    video_ids_response = task_output
    video_ids = [item['id']['videoId'] for item in video_ids_response['items']]

    if video_ids:
        context = get_current_context()
        context["task_instance"].xcom_push(key='video_ids', value={'id': ','.join(video_ids)})
        return 'video_data_to_s3'
    return 'no_video_ids'
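Because the function returns a task id ('video_data_to_s3' or 'no_video_ids'), it is presumably driving a branch. A minimal sketch of the wiring, assuming Airflow 2.3+'s @task.branch (the upstream output name is hypothetical):

from airflow.decorators import task

@task.branch
def choose_next(task_output):
    # Whatever task_id is returned runs; the other downstream task is skipped.
    return _check_and_transform_video_ids(task_output)

# choose_next(video_search_output) >> [video_data_to_s3, no_video_ids]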
Example #9
def _delete_dms_assets(dms_client):
    ti = get_current_context()['ti']
    replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
    source_arn = ti.xcom_pull(key='source_endpoint_arn')
    target_arn = ti.xcom_pull(key='target_endpoint_arn')

    print('Deleting DMS assets.')
    dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn)
    dms_client.delete_endpoint(EndpointArn=source_arn)
    dms_client.delete_endpoint(EndpointArn=target_arn)
Example #10
    def print_data(top_selling_game, best_seller_EU, top_publishers_jp,
                   best_platforms, num_games):
        context = get_current_context()
        date = context['ds']

        print(
            f'The top selling game for {date} is {top_selling_game}',
            f'The best selling genre in EU for the {date} is (are) {best_seller_EU}',
            f'The publisher(s) with the highest sales in JP for {date} is {top_publishers_jp}',
            f'Platforms with sales > 1M in NA for {date} are {best_platforms}',
            f'The number of games sold in EU more than in JP for the {date} is {num_games}',
            sep='\n',
        )
Example #11
        def dpti_gdi_loop_prepare():
            # Variable.set(self.var_name, 'run')

            context = get_current_context()
            dag_run = context['params']
            task0_dict = dag_run['task_dict_list'][0]
            task1_dict = dag_run['task_dict_list'][1]

            submission_dict = dag_run['submission_dict']
            # prepare_return = True
            # return (task0_dict, task1_dict)
            return {'task0_dict': task0_dict, 'task1_dict': task1_dict}
Example #12
        def wrapper(*f_args, **f_kwargs):
            print(f'f_args: {f_args}, f_kwargs: {f_kwargs}')
            # Make sure Edge objects in f_args are unpacked and their contents
            # passed to the function's *args; likewise for f_kwargs.
            skip = True
            if self.impure:
                skip = False
            incoming_edges = []
            f_args = list(f_args)
            for idx, arg in enumerate(f_args):
                if type(arg) is Edge:
                    incoming_edges.append(arg)
                    f_args[idx] = arg.data
            for key, kwarg in f_kwargs.items():
                if type(kwarg) is Edge:
                    incoming_edges.append(kwarg)
                    f_kwargs[key] = kwarg.data
            incoming_edges_ran = [edge.ran_this_time for edge in incoming_edges]
            if any(incoming_edges_ran):
                skip = False

            super_edge = Edge.union(incoming_edges)

            total_config = get_current_context()['params']

            sig = inspect.signature(func)
            func_params = sorted(sig.parameters.keys())

            cumulative_config = super_edge.cumulative_params
            for func_param in func_params:
                if func_param in total_config.keys():
                    cumulative_config[func_param] = total_config[func_param]

            # Sort the dict before stringifying so the hash is permutation-invariant.
            if cumulative_config:
                cumulative_config = OrderedDict(sorted(cumulative_config.items(), key=lambda t: t[0]))

            func_param_string = str(cumulative_config).encode('ascii')
            cache_dir_name = f'{func.__name__}_{hashlib.md5(func_param_string).hexdigest()}'
            cache_dir = os.path.join(self.root_path, cache_dir_name)
            if not os.path.exists(cache_dir):
                skip = False
                os.mkdir(cache_dir)
            if skip:
                skip = Edge.already_ran(cache_dir, cumulative_config)
            if skip:
                return Edge.decache(cache_dir)
            else:
                print(f'f_args: {f_args}, cumulative_config: {cumulative_config}')
                result = func(*f_args, **cumulative_config)
                edge = Edge(cumulative_params=cumulative_config, data=result, ran_this_time=True)
                edge.cache(cache_dir)
                return edge
Example #13
def get_sql():
    from splicemachine.features import FeatureStore

    params = get_current_context()['params']
    schema = params['schema']
    table = params['table']

    fs = FeatureStore()

    print("Getting backfill sql")
    sql = fs.get_backfill_sql(schema, table)
    print("Getting backfill intervals")
    intervals = fs.get_backfill_intervals(schema, table)
    return {"statement": sql, "params": intervals}
Example #14
def create_dms_assets():
    print('Creating DMS assets.')
    ti = get_current_context()['ti']
    dms_client = boto3.client('dms')
    rds_instance_endpoint = _get_rds_instance_endpoint()

    print('Creating replication instance.')
    instance_arn = dms_client.create_replication_instance(
        ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
        ReplicationInstanceClass='dms.t3.micro',
    )['ReplicationInstance']['ReplicationInstanceArn']

    ti.xcom_push(key='replication_instance_arn', value=instance_arn)

    print('Creating DMS source endpoint.')
    source_endpoint_arn = dms_client.create_endpoint(
        EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
        EndpointType='source',
        EngineName=RDS_ENGINE,
        Username=RDS_USERNAME,
        Password=RDS_PASSWORD,
        ServerName=rds_instance_endpoint['Address'],
        Port=rds_instance_endpoint['Port'],
        DatabaseName=RDS_DB_NAME,
    )['Endpoint']['EndpointArn']

    print('Creating DMS target endpoint.')
    target_endpoint_arn = dms_client.create_endpoint(
        EndpointIdentifier=TARGET_ENDPOINT_IDENTIFIER,
        EndpointType='target',
        EngineName='s3',
        S3Settings={
            'BucketName': S3_BUCKET,
            'BucketFolder': PROJECT_NAME,
            'ServiceAccessRoleArn': ROLE_ARN,
            'ExternalTableDefinition': json.dumps(TABLE_DEFINITION),
        },
    )['Endpoint']['EndpointArn']

    ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
    ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)

    print("Awaiting replication instance provisioning.")
    dms_client.get_waiter('replication_instance_available').wait(
        Filters=[{
            'Name': 'replication-instance-arn',
            'Values': [instance_arn]
        }])
Example #15
    def pause_all_dags():
        """
        #### Pause Loop

        Get the list of DAGs and pause them one-by-one.
        """
        context = get_current_context()

        # Run command to list all
        dags_json = check_output('airflow dags list -o json', shell=True)
        dags = json.loads(dags_json)

        for d in dags:
            if d['dag_id'] != context['dag'].dag_id and not d['paused']:
                print('Requesting pause of DAG: %s' % d['dag_id'])
                check_output('airflow dags pause %s' % d['dag_id'], shell=True)
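Shelling out to the CLI from inside a task works, but the same effect is available over Airflow's stable REST API (Airflow 2.x). A sketch under the assumptions that the API is enabled and that the base URL and credentials below are placeholders:

import requests

def pause_dag(dag_id, base_url='http://localhost:8080', auth=('admin', 'admin')):
    # PATCH /api/v1/dags/{dag_id} with is_paused=True pauses the DAG.
    resp = requests.patch(
        f'{base_url}/api/v1/dags/{dag_id}',
        json={'is_paused': True},
        auth=auth,
    )
    resp.raise_for_status()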
Example #16
def get_empty_submission(job_work_dir):
    context = get_current_context()
    dag_run = context['params']
    work_base_dir = dag_run['work_base_dir']

    with open(os.path.join(work_base_dir, 'machine.json'), 'r') as f:
        mdata = json.load(f)
    machine = Machine.load_from_dict(mdata['machine'])
    resources = Resources.load_from_dict(mdata['resources'])

    submission = Submission(
        work_base=job_work_dir,
        resources=resources,
        machine=machine,
    )
    return submission
Example #17
    def transform(multiple_outputs=True) -> Sequence:
        """
        #### Transform task
        A simple Transform task which pulls the collection returned by the
        upstream REST call and passes it through unchanged.
        """
        context = get_current_context()
        ti = context["ti"]
        data: Sequence = ti.xcom_pull(task_ids="rest-call-ex01",
                                      key="return_value")

        return data
Example #18
        def dpti_gdi_loop_md(task_dict):
            context = get_current_context()
            dag_run = context['params']

            submission_dict = dag_run['submission_dict']
            print('submission_dict', submission_dict)
            mdata = dag_run['mdata']
            print('mdata', mdata)
            print('debug:task_dict', task_dict)

            machine = Machine.load_from_machine_dict(mdata)
            batch = machine.batch
            submission = Submission.deserialize(
                submission_dict=submission_dict, batch=batch)
            submission.register_task(task=Task.deserialize(
                task_dict=task_dict))
            submission.run_submission()
            # md_return = prepare_return
            return True
Example #19
    def test_nested_context(self):
        """
        Nested execution contexts should be supported in case the user uses multiple context managers.
        Each time the execute method of an operator is called, we set a new 'current' context.
        This test verifies that order is preserved no matter how many contexts are entered.
        """
        max_stack_depth = 15
        ctx_list = []
        for i in range(max_stack_depth):
            # Create all contexts in ascending order
            new_context = {"ContextId": i}
            # Like 15 nested with statements
            ctx_obj = set_current_context(new_context)
            ctx_obj.__enter__()  # pylint: disable=E1101
            ctx_list.append(ctx_obj)
        for i in reversed(range(max_stack_depth)):
            # Iterate over contexts in reverse order - stack is LIFO
            ctx = get_current_context()
            assert ctx["ContextId"] == i
            # End of with statement
            ctx_list[i].__exit__(None, None, None)
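The same LIFO behaviour written with plain with statements, as a two-level sketch:

with set_current_context({"ContextId": 0}):
    with set_current_context({"ContextId": 1}):
        assert get_current_context()["ContextId"] == 1
    # Leaving the inner block restores the outer context.
    assert get_current_context()["ContextId"] == 0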
Example #20
def all_start_check():
    context = get_current_context()
    print(context)

    dag_run = context['params']
    work_base_dir = dag_run['work_base_dir']
    target_temp = int(dag_run['target_temp'])
    target_pres = int(dag_run['target_pres'])
    conf_lmp = str(dag_run['conf_lmp'])
    ti_path = str(dag_run['ti_path'])
    ens = str(dag_run['ens'])
    if_liquid = dag_run['if_liquid']

    work_base_abs_dir = os.path.realpath(work_base_dir)

    dag_work_dirname = f'{target_temp}K-{target_pres}bar-{conf_lmp}'
    dag_work_dir = os.path.join(work_base_abs_dir, dag_work_dirname)

    assert os.path.isdir(work_base_dir), f'work_base_dir {work_base_dir} must exist'
    if not os.path.isdir(dag_work_dir):
        os.mkdir(dag_work_dir)

    conf_lmp_abs_path = os.path.join(work_base_abs_dir, conf_lmp)
    assert os.path.isfile(conf_lmp_abs_path), f'structure file {conf_lmp_abs_path} must exist'
    assert ti_path in ("t", "p"), 'value for "path" must be "t" or "p"'

    start_info = dict(work_base_dir=work_base_dir, 
        target_temp=target_temp,
        target_pres=target_pres, 
        conf_lmp=conf_lmp, 
        ti_path=ti_path,
        ens=ens, 
        if_liquid=if_liquid, 
        work_base_abs_dir=work_base_abs_dir,
        dag_work_dir=dag_work_dir)
    print('start_info:', start_info)
    return start_info
Example #21
def delete_dms_assets():
    ti = get_current_context()['ti']
    dms_client = boto3.client('dms')
    replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
    source_arn = ti.xcom_pull(key='source_endpoint_arn')
    target_arn = ti.xcom_pull(key='target_endpoint_arn')

    print('Deleting DMS assets.')
    dms_client.delete_replication_instance(
        ReplicationInstanceArn=replication_instance_arn)
    dms_client.delete_endpoint(EndpointArn=source_arn)
    dms_client.delete_endpoint(EndpointArn=target_arn)

    print('Awaiting DMS assets tear-down.')
    dms_client.get_waiter('replication_instance_deleted').wait(
        Filters=[{
            'Name': 'replication-instance-id',
            'Values': [DMS_REPLICATION_INSTANCE_NAME]
        }])
    dms_client.get_waiter('endpoint_deleted').wait(
        Filters=[{
            'Name': 'endpoint-id',
            'Values': [SOURCE_ENDPOINT_IDENTIFIER, TARGET_ENDPOINT_IDENTIFIER],
        }])
Example #22
    def test_current_context_roundtrip(self):
        example_context = {"Hello": "World"}

        with set_current_context(example_context):
            assert get_current_context() == example_context
Example #23
def whom() -> str:
    context = get_current_context()
    return context["params"]["who"].capitalize()
Example #24
def update_ip_access_unnormal_address_task():
    context = get_current_context()
    execution_date = (context['execution_date'] + timedelta(days=1)).strftime('%Y%m%d')
    if int(execution_date) <= 20210216:
        return
    update_ip_access_unnormal_address()
    return True
Example #25
def backend_info_analysis_task():
    context = get_current_context()
    execution_date = (context['execution_date'] + timedelta(days=1)).strftime('%Y%m%d')
    if int(execution_date) <= 20210216:
        return
    backend_info_analysis()
    return True
Example #26
def generate_execution_date():
    context = get_current_context()
    execution_date = (context['execution_date'] + timedelta(days=1)).strftime('%Y%m%d')
    return execution_date
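The three tasks above all shift the logical date forward one day by hand; for the unshifted date the context already carries pre-formatted strings. A sketch assuming the standard Airflow context keys:

def generate_ds_nodash():
    # 'ds' is the logical date as 'YYYY-MM-DD'; 'ds_nodash' is 'YYYYMMDD'.
    # Neither applies the +1 day shift used in the examples above.
    context = get_current_context()
    return context["ds_nodash"]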
Example #27
    def test_current_context_no_context_raise(self):
        with pytest.raises(AirflowException):
            get_current_context()
Example #28
def print_value(value):
    """
    Able to get the current context here.
    """
    ctx = get_current_context()
    log.info("The knights of Ni say: %s (at %s)", value, ctx['ts'])
Example #29
def get_all_the_context(**context):
    current_context = get_current_context()
    assert context == current_context
Example #30
    def execute(self, context):
        assert context == get_current_context()