예제 #1
0
    def has_asset_key(self, asset_key: AssetKey) -> bool:
        """Return whether any stored row matches ``asset_key``.

        The lookup goes against ``AssetKeyTable`` when the asset-key secondary
        index is enabled and against ``SqlEventLogStorageTable`` otherwise.
        Both the current and the legacy string form of the key are matched.
        """
        check.inst_param(asset_key, "asset_key", AssetKey)

        # Only the probed column differs between the two code paths.
        if self.has_secondary_index(SECONDARY_INDEX_ASSET_KEY):
            key_column = AssetKeyTable.c.asset_key
        else:
            key_column = SqlEventLogStorageTable.c.asset_key

        matches_key = db.or_(
            key_column == asset_key.to_string(),
            key_column == asset_key.to_string(legacy=True),
        )
        # A single row is enough to answer an existence question.
        query = db.select([1]).where(matches_key).limit(1)

        with self.index_connection() as conn:
            rows = conn.execute(query).fetchall()

        return bool(rows)
예제 #2
0
    def has_asset_key(self, asset_key: AssetKey) -> bool:
        """Return whether ``asset_key`` is present and not currently wiped.

        The key is looked up in ``AssetKeyTable`` (current and legacy string
        forms). If the stored asset details carry a ``last_wipe_timestamp``,
        the key only counts as present when the event log holds a row for
        this key whose timestamp is later than the wipe.

        Args:
            asset_key: The asset key to look up.

        Returns:
            ``True`` if the key exists and has event-log activity after its
            last wipe (or was never wiped), ``False`` otherwise.
        """
        check.inst_param(asset_key, "asset_key", AssetKey)

        def matches_key(column):
            # Match both the current and the legacy serialization of the key.
            return db.or_(
                column == asset_key.to_string(),
                column == asset_key.to_string(legacy=True),
            )

        query = (
            db.select([AssetKeyTable.c.asset_key, AssetKeyTable.c.asset_details])
            .where(matches_key(AssetKeyTable.c.asset_key))
            .limit(1)
        )

        with self.index_connection() as conn:
            row = conn.execute(query).fetchone()
            if not row:
                return False

            asset_details: Optional[AssetDetails] = AssetDetails.from_db_string(row[1])
            if not asset_details or not asset_details.last_wipe_timestamp:
                # Never wiped: presence in the index table is sufficient.
                return True

            # Filter on the event-log table's own asset_key column. Filtering
            # on AssetKeyTable here would implicitly cross join the two tables
            # and return the newest timestamp of *any* event, not this asset's.
            # NOTE: no event-type filter is applied; any event-log row tagged
            # with this key counts.
            materialization_row = conn.execute(
                db.select([SqlEventLogStorageTable.c.timestamp])
                .where(matches_key(SqlEventLogStorageTable.c.asset_key))
                .order_by(SqlEventLogStorageTable.c.timestamp.desc())
                .limit(1)
            ).fetchone()
            if not materialization_row:
                return False

            latest_event_time = utc_datetime_from_naive(materialization_row[0])
            wiped_at = utc_datetime_from_timestamp(asset_details.last_wipe_timestamp)
            return latest_event_time > wiped_at
예제 #3
0
    def get_all_asset_keys(self, prefix_path=None):
        """Return the stored asset keys, optionally restricted to a prefix.

        Reads ``AssetKeyTable`` when the asset-key secondary index is enabled;
        otherwise reads the distinct non-null ``asset_key`` values of
        ``SqlEventLogStorageTable``. A truthy ``prefix_path`` adds a
        ``startswith`` filter built from ``AssetKey.get_db_prefix``.
        """
        indexed = self.has_secondary_index(SECONDARY_INDEX_ASSET_KEY)
        if indexed:
            key_column = AssetKeyTable.c.asset_key
        else:
            key_column = SqlEventLogStorageTable.c.asset_key

        query = db.select([key_column])
        if not indexed:
            # SQLAlchemy renders `!= None` as IS NOT NULL.
            query = query.where(key_column != None)
        if prefix_path:
            query = query.where(
                key_column.startswith(AssetKey.get_db_prefix(prefix_path))
            )
        if not indexed:
            query = query.distinct()

        with self.connect() as conn:
            rows = conn.execute(query).fetchall()

        asset_keys = []
        for (db_string,) in rows:
            if db_string:
                asset_keys.append(AssetKey.from_db_string(db_string))
        return asset_keys
예제 #4
0
    def get_latest_materialization_events(
        self, asset_keys: Sequence[AssetKey]
    ) -> Mapping[AssetKey, Optional[EventLogEntry]]:
        """Map asset keys to their latest materialization event-log entry.

        Rows returned by ``self._fetch_asset_rows`` whose second column
        deserializes to an ``EventLogEntry`` are used directly. Every other
        key is resolved from ``SqlEventLogStorageTable`` by picking the
        ``ASSET_MATERIALIZATION`` row with the greatest timestamp. Keys for
        which neither source yields an entry are absent from the result.
        """
        check.list_param(asset_keys, "asset_keys", AssetKey)

        latest_by_key: Dict[AssetKey, Optional[EventLogEntry]] = {}
        needs_event_log_lookup = set()
        for asset_row in self._fetch_asset_rows(asset_keys=asset_keys):
            parsed_key = AssetKey.from_db_string(asset_row[0])
            if not parsed_key:
                continue
            stored = None
            if asset_row[1]:
                stored = deserialize_json_to_dagster_namedtuple(asset_row[1])
            if isinstance(stored, EventLogEntry):
                latest_by_key[parsed_key] = stored
            else:
                # The stored value is missing or is not an EventLogEntry;
                # fall back to the event log for this key.
                needs_event_log_lookup.add(parsed_key)

        if not needs_event_log_lookup:
            return latest_by_key

        events = SqlEventLogStorageTable.c
        # Per asset key, the greatest timestamp among materialization events.
        newest_timestamps = (
            db.select(
                [
                    events.asset_key,
                    db.func.max(events.timestamp).label("timestamp"),
                ]
            )
            .where(
                db.and_(
                    events.asset_key.in_(
                        [pending.to_string() for pending in needs_event_log_lookup]
                    ),
                    events.dagster_event_type
                    == DagsterEventType.ASSET_MATERIALIZATION.value,
                )
            )
            .group_by(events.asset_key)
            .subquery()
        )
        # Join back to the event log to fetch the serialized event itself.
        join_condition = db.and_(
            events.asset_key == newest_timestamps.c.asset_key,
            events.timestamp == newest_timestamps.c.timestamp,
        )
        backcompat_query = db.select([events.asset_key, events.event]).join(
            newest_timestamps, join_condition
        )

        with self.index_connection() as conn:
            event_rows = conn.execute(backcompat_query).fetchall()

        for event_row in event_rows:
            parsed_key = AssetKey.from_db_string(event_row[0])
            if parsed_key:
                latest_by_key[parsed_key] = cast(
                    EventLogEntry, deserialize_json_to_dagster_namedtuple(event_row[1])
                )

        return latest_by_key
예제 #5
0
    def get_asset_keys(self, prefix_path=None):
        """Return the de-duplicated asset keys, optionally filtered by prefix.

        Reads ``AssetKeyTable`` when the asset-key secondary index is enabled,
        otherwise the distinct non-null ``asset_key`` values of
        ``SqlEventLogStorageTable``. A truthy ``prefix_path`` matches both the
        current and the legacy db prefix. An unfiltered read without the index
        also triggers ``_lazy_migrate_secondary_index_asset_key`` with the raw
        key strings that were read.
        """
        indexed = self.has_secondary_index(SECONDARY_INDEX_ASSET_KEY)
        if indexed:
            key_column = AssetKeyTable.c.asset_key
        else:
            key_column = SqlEventLogStorageTable.c.asset_key

        query = db.select([key_column])
        if not indexed:
            # SQLAlchemy renders `!= None` as IS NOT NULL.
            query = query.where(key_column != None)
        if prefix_path:
            query = query.where(
                db.or_(
                    key_column.startswith(AssetKey.get_db_prefix(prefix_path)),
                    key_column.startswith(
                        AssetKey.get_db_prefix(prefix_path, legacy=True)
                    ),
                )
            )
        if not indexed:
            query = query.distinct()

        # This is in place to migrate everyone to using the secondary index table for asset
        # keys. Performing this migration should result in a big performance boost for
        # any asset-catalog reads.
        #
        # After a sufficient amount of time (>= 0.11.0?), the checks for
        # has_secondary_index(SECONDARY_INDEX_ASSET_KEY) can be removed and reads can always
        # go to AssetKeyTable, since writes already go there. Tracking the conditional
        # check removal here: https://github.com/dagster-io/dagster/issues/3507
        lazy_migrate = not prefix_path and not indexed

        with self.index_connection() as conn:
            rows = conn.execute(query).fetchall()

        db_strings = [db_string for (db_string,) in rows if db_string]
        if lazy_migrate:
            self._lazy_migrate_secondary_index_asset_key(db_strings)

        return list({AssetKey.from_db_string(db_string) for db_string in db_strings})
예제 #6
0
    def all_asset_keys(self):
        """Return the asset keys in ``AssetKeyTable`` that are not wiped.

        A key counts as wiped when its asset details carry a
        ``last_wipe_timestamp`` and the event log has no row for that key, or
        the greatest event-log timestamp for that key is earlier than the wipe.
        """
        with self.index_connection() as conn:
            key_rows = conn.execute(
                db.select([AssetKeyTable.c.asset_key, AssetKeyTable.c.asset_details])
            ).fetchall()

            known_keys = set()
            wiped_keys = set()
            wipe_times = {}
            for key_row in key_rows:
                parsed_key = AssetKey.from_db_string(key_row[0])
                details: Optional[AssetDetails] = AssetDetails.from_db_string(key_row[1])
                known_keys.add(parsed_key)
                if details and details.last_wipe_timestamp:
                    wipe_times[parsed_key] = details.last_wipe_timestamp

            if wipe_times:
                events = SqlEventLogStorageTable.c
                newest = db.func.max(events.timestamp)
                # Greatest event-log timestamp per wiped key. Rows arrive in
                # ascending timestamp order, so when two db strings parse to
                # the same key the later timestamp is the one kept below.
                latest_rows = conn.execute(
                    db.select([events.asset_key, newest])
                    .where(
                        events.asset_key.in_(
                            [wiped_key.to_string() for wiped_key in wipe_times]
                        )
                    )
                    .group_by(events.asset_key)
                    .order_by(db.func.max(events.timestamp).asc())
                ).fetchall()

                latest_event_times = {}
                for latest_row in latest_rows:
                    latest_event_times[AssetKey.from_db_string(latest_row[0])] = latest_row[1]

                # A wiped key stays wiped until an event newer than the wipe exists.
                for wiped_key, wiped_at in wipe_times.items():
                    last_seen = latest_event_times.get(wiped_key)
                    if not last_seen or utc_datetime_from_naive(
                        last_seen
                    ) < utc_datetime_from_timestamp(wiped_at):
                        wiped_keys.add(wiped_key)

        return list(known_keys - wiped_keys)
예제 #7
0
    def has_asset_key(self, asset_key: AssetKey) -> bool:
        """Return whether ``AssetKeyTable`` holds a row for ``asset_key``.

        Both the current and the legacy string form of the key are matched.
        """
        check.inst_param(asset_key, "asset_key", AssetKey)

        key_column = AssetKeyTable.c.asset_key
        matches_key = db.or_(
            key_column == asset_key.to_string(),
            key_column == asset_key.to_string(legacy=True),
        )
        # One row is enough to establish existence.
        query = db.select([1]).where(matches_key).limit(1)

        with self.index_connection() as conn:
            rows = conn.execute(query).fetchall()

        return bool(rows)
예제 #8
0
    def __call__(self, fn: Callable) -> AssetsDefinition:
        """Wrap ``fn`` in an op and return the ``AssetsDefinition`` for it.

        The asset is named ``self.name`` or, failing that, ``fn.__name__``.
        Its key is built from the truthy entries of
        ``[self.namespace, asset_name]``, and the op's single output is
        registered under the name ``"result"``.
        """
        asset_name = self.name or fn.__name__

        asset_ins: Mapping[str, In] = build_asset_ins(
            fn, self.namespace, self.ins or {}, self.non_argument_deps
        )

        partition_fn: Optional[Callable] = None
        if self.partitions_def:

            def partition_fn(context):  # pylint: disable=function-redefined
                return [context.partition_key]

        out_asset_key = AssetKey(
            [part for part in (self.namespace, asset_name) if part]
        )
        op = _Op(
            name=asset_name,
            description=self.description,
            ins=dict(asset_ins.items()),  # convert Mapping object to dict
            out=Out(
                asset_key=out_asset_key,
                metadata=self.metadata or {},
                io_manager_key=self.io_manager_key,
                dagster_type=self.dagster_type,
                asset_partitions_def=self.partitions_def,
                asset_partitions=partition_fn,
            ),
            required_resource_keys=self.required_resource_keys,
            tags={"kind": self.compute_kind} if self.compute_kind else None,
        )(fn)

        input_names_by_asset_key = {
            in_def.asset_key: input_name for input_name, in_def in asset_ins.items()
        }

        # Re-key the partition mappings from input name to upstream asset key.
        if self.partition_mappings:
            partition_mappings = {
                asset_ins[input_name].asset_key: mapping
                for input_name, mapping in self.partition_mappings.items()
            }
        else:
            partition_mappings = None

        return AssetsDefinition(
            input_names_by_asset_key=input_names_by_asset_key,
            output_names_by_asset_key={out_asset_key: "result"},
            op=op,
            partitions_def=self.partitions_def,
            partition_mappings=partition_mappings,
        )
예제 #9
0
    def get_all_asset_keys(self, prefix_path=None):
        """Return the distinct asset keys found in the event log.

        A truthy ``prefix_path`` restricts the result to keys whose stored
        string starts with ``AssetKey.get_db_prefix(prefix_path)``. Rows with
        a falsy ``asset_key`` value are dropped.
        """
        key_column = SqlEventLogStorageTable.c.asset_key

        query = db.select([key_column])
        if prefix_path:
            query = query.where(
                key_column.startswith(AssetKey.get_db_prefix(prefix_path))
            )
        query = query.distinct()

        with self.connect() as conn:
            rows = conn.execute(query).fetchall()

        asset_keys = []
        for (db_string,) in rows:
            if db_string:
                asset_keys.append(AssetKey.from_db_string(db_string))
        return asset_keys
예제 #10
0
파일: asset.py 프로젝트: zuik/dagster
def asset_wipe_command(key, **cli_args):
    """Wipe asset key indexes after an interactive ``DELETE`` confirmation.

    Args:
        key: Sequence of asset key strings to wipe; each one is parsed with
            ``AssetKey.from_db_string``.
        **cli_args: CLI options. Only ``all`` is read here; when it is set,
            every key returned by ``instance.all_asset_keys()`` is wiped.

    Raises:
        click.UsageError: If neither or both of ``key`` and ``--all`` are given.
    """
    wipe_all = cli_args.get("all")

    if not wipe_all and len(key) == 0:
        raise click.UsageError(
            "Error, you must specify an asset key or use `--all` to wipe all asset keys."
        )

    if wipe_all and len(key) > 0:
        raise click.UsageError("Error, cannot use more than one of: asset key, `--all`.")

    with DagsterInstance.get() as instance:
        if len(key) > 0:
            asset_keys = [AssetKey.from_db_string(key_string) for key_string in key]
            prompt = (
                "Are you sure you want to remove the asset key indexes for these keys from the event "
                "logs? Type DELETE"
            )
        else:
            asset_keys = instance.all_asset_keys()
            prompt = "Are you sure you want to remove all asset key indexes from the event logs? Type DELETE"

        confirmation = click.prompt(prompt)
        if confirmation == "DELETE":
            # Reuse the instance that is already open instead of acquiring a
            # second one inside this context (which also shadowed `instance`).
            instance.wipe_assets(asset_keys)
            click.echo("Removed asset indexes from event logs")
        else:
            click.echo("Exiting without removing asset indexes")
예제 #11
0
파일: decorators.py 프로젝트: keyz/dagster
    def inner(fn: Callable[..., Any]) -> AssetsDefinition:
        """Wrap ``fn`` in an op and return its ``AssetsDefinition``.

        ``name``, ``ins``, ``outs``, ``non_argument_deps``, ``description``,
        ``required_resource_keys`` and ``compute_kind`` are read from the
        enclosing scope. An output whose ``asset_key`` is not an ``AssetKey``
        gets the key ``AssetKey([output_name])``.
        """
        asset_name = name or fn.__name__
        asset_ins: Mapping[str, In] = build_asset_ins(
            fn, None, ins or {}, non_argument_deps
        )

        op = _Op(
            name=asset_name,
            description=description,
            ins=dict(asset_ins.items()),  # convert Mapping object to dict
            out=outs,
            required_resource_keys=required_resource_keys,
            tags={"kind": compute_kind} if compute_kind else None,
        )(fn)

        input_names_by_asset_key = {
            in_def.asset_key: input_name for input_name, in_def in asset_ins.items()
        }

        output_names_by_asset_key = {}
        for output_name, out_def in outs.items():
            if isinstance(out_def.asset_key, AssetKey):
                output_key = out_def.asset_key
            else:
                output_key = AssetKey([output_name])
            output_names_by_asset_key[output_key] = output_name

        return AssetsDefinition(
            input_names_by_asset_key=input_names_by_asset_key,
            output_names_by_asset_key=output_names_by_asset_key,
            op=op,
        )
예제 #12
0
파일: asset.py 프로젝트: xyzlat/dagster
def asset_wipe_command(key, **cli_args):
    """Wipe asset key indexes after an interactive ``DELETE`` confirmation.

    Exactly one of ``key`` (asset key strings) or the ``all`` option must be
    given; otherwise a ``click.UsageError`` is raised. Specific keys go to
    ``instance.wipe_assets``; ``all`` goes to ``instance.wipe_all_assets``.
    """
    wipe_all = cli_args.get('all')
    has_keys = len(key) > 0

    if not wipe_all and not has_keys:
        raise click.UsageError(
            'Error, you must specify an asset key or use `--all` to wipe all asset keys.'
        )
    if wipe_all and has_keys:
        raise click.UsageError(
            'Error, cannot use more than one of: asset key, `--all`.')

    if has_keys:
        asset_keys = [AssetKey.from_db_string(key_string) for key_string in key]
        prompt = (
            'Are you sure you want to remove the asset key indexes for these keys from the event '
            'logs? Type DELETE')
    else:
        asset_keys = None
        prompt = (
            'Are you sure you want to remove all asset key indexes from the event logs? Type DELETE'
        )

    # Anything other than the literal DELETE aborts without touching storage.
    if click.prompt(prompt) != 'DELETE':
        click.echo('Exiting without removing asset indexes')
        return

    instance = DagsterInstance.get()
    if asset_keys:
        instance.wipe_assets(asset_keys)
    else:
        instance.wipe_all_assets()
    click.echo('Removed asset indexes from event logs')
예제 #13
0
    def set_asset(self, context, step_output_handle, obj, asset_metadata):
        """Pickle ``obj`` to the path named in ``asset_metadata["path"]``.

        The target's parent directory is created first. Returns an
        ``AssetMaterialization`` keyed by pipeline name, step key and output
        name, carrying the absolute file path as a metadata entry.
        """
        check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle)
        relative_path = check.str_param(asset_metadata.get("path"), "asset_metadata.path")

        target = self._get_path(relative_path)

        # The parent directory must exist before the file can be opened.
        mkdir_p(os.path.dirname(target))

        with open(target, self.write_mode) as sink:
            pickle.dump(obj, sink, PICKLE_PROTOCOL)

        key_parts = [
            context.pipeline_def.name,
            step_output_handle.step_key,
            step_output_handle.output_name,
        ]
        fspath_entry = EventMetadataEntry.fspath(os.path.abspath(target))
        return AssetMaterialization(
            asset_key=AssetKey(key_parts),
            metadata_entries=[fspath_entry],
        )
예제 #14
0
def test_source_asset_partitions():
    """A daily asset reading an hourly source asset sees a full day of hours."""
    hourly_source = SourceAsset(
        AssetKey("hourly_asset"),
        partitions_def=HourlyPartitionsDefinition(start_date="2021-05-05-00:00"),
    )

    @asset(partitions_def=DailyPartitionsDefinition(start_date="2021-05-05"))
    def daily_asset(hourly_asset):
        assert hourly_asset is None

    class RangeCheckingIOManager(IOManager):
        def handle_output(self, context, obj):
            pass

        def load_input(self, context):
            # The 2021-06-06 daily partition should map to its 24 hourly ones.
            partition_range = context.asset_partition_key_range
            assert partition_range.start == "2021-06-06-00:00"
            assert partition_range.end == "2021-06-06-23:00"

    io_manager_def = IOManagerDefinition.hardcoded_io_manager(RangeCheckingIOManager())
    job = build_assets_job(
        name="daily_job",
        assets=[daily_asset],
        source_assets=[hourly_source],
        resource_defs={"io_manager": io_manager_def},
    )
    assert job.execute_in_process(partition_key="2021-06-06").success
예제 #15
0
    def handle_output(self, context, obj):
        """Pickle ``obj`` to the path named in ``context.metadata["path"]``.

        The target's parent directory is created first. Returns an
        ``AssetMaterialization`` keyed by pipeline name, step key and output
        name, carrying the absolute file path as a metadata entry.
        """
        check.inst_param(context, "context", OutputContext)
        output_metadata = context.metadata
        relative_path = check.str_param(output_metadata.get("path"), "metadata.path")

        target = self._get_path(relative_path)

        # The parent directory must exist before the file can be opened.
        mkdir_p(os.path.dirname(target))
        context.log.debug(f"Writing file at: {target}")

        with open(target, self.write_mode) as sink:
            pickle.dump(obj, sink, PICKLE_PROTOCOL)

        key_parts = [context.pipeline_name, context.step_key, context.name]
        fspath_entry = EventMetadataEntry.fspath(os.path.abspath(target))
        return AssetMaterialization(
            asset_key=AssetKey(key_parts),
            metadata_entries=[fspath_entry],
        )
예제 #16
0
def test_single_partitioned_asset_job():
    """Running partition "b" materializes the asset for that partition only."""
    static_partitions = StaticPartitionsDefinition(["a", "b", "c", "d"])

    class PartitionAssertingIOManager(IOManager):
        def handle_output(self, context, obj):
            assert context.asset_partition_key == "b"

        def load_input(self, context):
            # my_asset has no inputs, so nothing should ever be loaded.
            assert False, "shouldn't get here"

    @asset(partitions_def=static_partitions)
    def my_asset():
        pass

    io_manager_def = IOManagerDefinition.hardcoded_io_manager(PartitionAssertingIOManager())
    job = build_assets_job(
        "my_job",
        assets=[my_asset],
        resource_defs={"io_manager": io_manager_def},
    )
    run_result = job.execute_in_process(partition_key="b")

    expected = [AssetMaterialization(asset_key=AssetKey(["my_asset"]), partition="b")]
    assert run_result.asset_materializations_for_node("my_asset") == expected
예제 #17
0
def build_asset_ins(
    fn: Callable,
    asset_namespace: Optional[Sequence[str]],
    asset_ins: Mapping[str, AssetIn],
    non_argument_deps: Optional[Set[AssetKey]],
) -> Dict[str, In]:
    """Build the ``In`` definitions for an asset's decorated function.

    One ``In`` is created per function parameter (skipping a leading
    ``context`` parameter), in declaration order. Each entry of
    ``non_argument_deps`` then adds a ``Nothing``-typed ``In`` named by
    joining the key's path with ``"_"``.

    Args:
        fn: The decorated function whose parameters define the inputs.
        asset_namespace: Namespace used for an input's asset key when the
            matching ``AssetIn`` gives neither an asset key nor a namespace.
        asset_ins: Per-parameter overrides (asset key, metadata, namespace).
        non_argument_deps: Upstream asset keys that are not function arguments.

    Returns:
        Mapping of input name to ``In``.

    Raises:
        DagsterInvalidDefinitionError: If a key of ``asset_ins`` does not name
            a parameter of ``fn``.
    """
    non_argument_deps = check.opt_set_param(
        non_argument_deps, "non_argument_deps", AssetKey
    )

    params = get_function_params(fn)
    is_context_provided = len(params) > 0 and params[
        0
    ].name in get_valid_name_permutations("context")
    input_param_names = [
        input_param.name
        for input_param in (params[1:] if is_context_provided else params)
    ]

    for in_key in asset_ins.keys():
        if in_key not in input_param_names:
            raise DagsterInvalidDefinitionError(
                f"Key '{in_key}' in provided ins dict does not correspond to any of the names "
                "of the arguments to the decorated function")

    ins: Dict[str, In] = {}
    # After the validation above every key of `asset_ins` is a parameter name,
    # so the parameter list already covers all inputs. Iterating it (instead
    # of a set) keeps the order of `ins` stable across interpreter runs.
    for input_name in input_param_names:
        asset_key = None
        metadata = {}
        namespace = None

        if input_name in asset_ins:
            asset_in = asset_ins[input_name]
            asset_key = asset_in.asset_key
            metadata = asset_in.metadata or {}
            namespace = asset_in.namespace

        # An explicit asset key wins; otherwise derive one from the namespace
        # (the input's own, else the asset's) plus the input name.
        asset_key = asset_key or AssetKey(
            list(filter(None, [*(namespace or asset_namespace or []), input_name]))
        )

        ins[input_name] = In(
            metadata=metadata,
            root_manager_key="root_manager",
            asset_key=asset_key,
            dagster_type=None,
        )

    # Sort the set so the resulting order does not depend on hashing.
    for asset_key in sorted(non_argument_deps, key=lambda dep: list(dep.path)):
        stringified_asset_key = "_".join(asset_key.path)
        if stringified_asset_key:
            ins[stringified_asset_key] = In(dagster_type=Nothing, asset_key=asset_key)

    return ins
예제 #18
0
def test_two_partitioned_assets_job():
    """Both assets of a two-asset job materialize for the requested partition."""
    partition_keys = ["a", "b", "c", "d"]

    @asset(partitions_def=StaticPartitionsDefinition(partition_keys))
    def upstream():
        pass

    @asset(partitions_def=StaticPartitionsDefinition(partition_keys))
    def downstream(upstream):
        assert upstream is None

    job = build_assets_job("my_job", assets=[upstream, downstream])
    run_result = job.execute_in_process(partition_key="b")

    for node_name in ("upstream", "downstream"):
        assert run_result.asset_materializations_for_node(node_name) == [
            AssetMaterialization(AssetKey([node_name]), partition="b")
        ]
예제 #19
0
 def mutate(self, graphene_info, **kwargs):
     """Wipe the assets listed under the ``assetKeys`` keyword argument.

     Each entry is converted with ``AssetKey.from_graphql_input`` and the
     resulting list is handed to ``wipe_assets``, whose result is returned.
     """
     asset_keys = [
         AssetKey.from_graphql_input(key_input) for key_input in kwargs["assetKeys"]
     ]
     return wipe_assets(graphene_info, asset_keys)
예제 #20
0
    def get_all_asset_keys(self):
        """Return the distinct asset keys found in the event log.

        Rows with a falsy ``asset_key`` value are dropped.
        """
        distinct_keys = db.select([SqlEventLogStorageTable.c.asset_key]).distinct()

        with self.connect() as conn:
            rows = conn.execute(distinct_keys).fetchall()

        asset_keys = []
        for (db_string,) in rows:
            if db_string:
                asset_keys.append(AssetKey.from_db_string(db_string))
        return asset_keys
예제 #21
0
 def get_asset_keys(
     self,
     prefix: Optional[List[str]] = None,
     limit: Optional[int] = None,
     cursor: Optional[str] = None,
 ) -> Iterable[AssetKey]:
     """Return the asset keys from ``_fetch_asset_rows``, sorted by stored string.

     ``prefix``, ``limit`` and ``cursor`` are forwarded unchanged. Rows whose
     first column does not parse to a truthy ``AssetKey`` are dropped.
     """
     fetched = self._fetch_asset_rows(prefix=prefix, limit=limit, cursor=cursor)
     ordered = sorted(fetched, key=lambda x: x[0])

     parsed_keys = []
     for asset_row in ordered:
         parsed = AssetKey.from_db_string(asset_row[0])
         if parsed:
             parsed_keys.append(parsed)
     return parsed_keys
예제 #22
0
    def delete_events_for_run(self, conn, run_id):
        """Delete a run's event-log rows and prune orphaned asset-key rows.

        The asset keys referenced by the run are collected before the delete.
        Afterwards, any of those keys that no longer appears in the event log
        (under its current or legacy string form) is removed from
        ``AssetKeyTable``.
        """
        check.str_param(run_id, "run_id")

        def both_string_forms(keys):
            # Current forms first, then legacy forms, for the same keys.
            return [k.to_string() for k in keys] + [
                k.to_string(legacy=True) for k in keys
            ]

        events = SqlEventLogStorageTable.c

        delete_statement = (
            SqlEventLogStorageTable.delete().where(  # pylint: disable=no-value-for-parameter
                events.run_id == run_id
            )
        )
        keys_of_run_query = (
            db.select([events.asset_key])
            .where(events.run_id == run_id)
            .where(events.asset_key != None)
            .group_by(events.asset_key)
        )

        # The keys must be read before the rows that carry them are deleted.
        affected_keys = [
            AssetKey.from_db_string(key_row[0])
            for key_row in conn.execute(keys_of_run_query).fetchall()
        ]
        conn.execute(delete_statement)

        if not affected_keys:
            return

        still_referenced_query = (
            db.select([events.asset_key])
            .where(events.asset_key.in_(both_string_forms(affected_keys)))
            .group_by(events.asset_key)
        )
        still_referenced = [
            AssetKey.from_db_string(key_row[0])
            for key_row in conn.execute(still_referenced_query)
        ]

        orphaned = set(affected_keys) - set(still_referenced)
        if orphaned:
            conn.execute(
                AssetKeyTable.delete().where(  # pylint: disable=no-value-for-parameter
                    AssetKeyTable.c.asset_key.in_(both_string_forms(orphaned))
                )
            )
예제 #23
0
def test_single_partitioned_asset_job():
    """Running partition "b" yields one materialization tagged with "b"."""
    static_partitions = StaticPartitionsDefinition(["a", "b", "c", "d"])

    @asset(partitions_def=static_partitions)
    def my_asset():
        pass

    job = build_assets_job("my_job", assets=[my_asset])
    run_result = job.execute_in_process(partition_key="b")

    expected = [AssetMaterialization(asset_key=AssetKey(["my_asset"]), partition="b")]
    assert run_result.asset_materializations_for_node("my_asset") == expected
예제 #24
0
 def _lazy_migrate_secondary_index_asset_key(self, conn, asset_keys):
     """Copy missing asset key strings into ``AssetKeyTable`` and enable the index.

     ``asset_keys`` holds stored key strings. Those not already in
     ``AssetKeyTable`` are inserted in their ``to_string()`` form, and the
     asset-key secondary index is enabled afterwards.
     """
     indexed_rows = conn.execute(db.select([AssetKeyTable.c.asset_key])).fetchall()
     already_indexed = {db_string for (db_string,) in indexed_rows if db_string}

     for db_string in set(asset_keys) - already_indexed:
         insert_statement = AssetKeyTable.insert().values(  # pylint: disable=no-value-for-parameter
             asset_key=AssetKey.from_db_string(db_string).to_string()
         )
         try:
             conn.execute(insert_statement)
         except db.exc.IntegrityError:
             # asset key already present
             pass

     self.enable_secondary_index(SECONDARY_INDEX_ASSET_KEY)
예제 #25
0
    def all_asset_tags(self):
        """Return the tags of each asset's last materialization, keyed by asset.

        Only assets whose stored last materialization parses and has non-empty
        tags get an entry. The result is a ``defaultdict(dict)``.
        """
        tags_by_asset_key = defaultdict(dict)
        query = db.select([AssetKeyTable.c.asset_key, AssetKeyTable.c.last_materialization])

        with self.index_connection() as conn:
            for db_string, json_str in conn.execute(query).fetchall():
                materialization = self._asset_materialization_from_json_column(json_str)
                if not (materialization and materialization.tags):
                    continue
                # Store a plain-dict copy of the tags.
                tags_by_asset_key[AssetKey.from_db_string(db_string)] = dict(
                    materialization.tags.items()
                )

        return tags_by_asset_key
예제 #26
0
    def all_asset_tags(self):
        """Return the tags of each asset's last materialization, keyed by asset.

        Every asset with a truthy ``last_materialization`` value gets an
        entry, which is an empty dict when the materialization has no tags.
        The result is a ``defaultdict(dict)``.
        """
        tags_by_asset_key = defaultdict(dict)
        query = db.select([AssetKeyTable.c.asset_key, AssetKeyTable.c.last_materialization])

        with self.index_connection() as conn:
            for db_string, materialization_str in conn.execute(query).fetchall():
                if not materialization_str:
                    continue
                materialization = deserialize_json_to_dagster_namedtuple(materialization_str)
                # Store a plain-dict copy of the tags.
                tags_by_asset_key[AssetKey.from_db_string(db_string)] = dict(
                    (materialization.tags or {}).items()
                )

        return tags_by_asset_key
예제 #27
0
    def __call__(self, fn: Callable) -> AssetsDefinition:
        """Wrap ``fn`` in an op and return the ``AssetsDefinition`` for it.

        The asset key is the namespace parts followed by the asset name
        (``self.name`` or ``fn.__name__``), with falsy parts dropped. The op is
        named by joining that key's path with ``"__"``, and its single output
        is registered under the name ``"result"``.
        """
        asset_name = self.name or fn.__name__

        input_defs = build_asset_ins(
            fn, self.namespace, self.ins or {}, self.non_argument_deps
        )

        partition_fn: Optional[Callable] = None
        if self.partitions_def:

            def partition_fn(context):  # pylint: disable=function-redefined
                return [context.partition_key]

        out_asset_key = AssetKey(
            [part for part in (*(self.namespace or []), asset_name) if part]
        )
        output_def = Out(
            asset_key=out_asset_key,
            metadata=self.metadata or {},
            io_manager_key=self.io_manager_key,
            dagster_type=self.dagster_type,
            asset_partitions_def=self.partitions_def,
            asset_partitions=partition_fn,
        )
        # Both partition entries of the op config are optional dicts.
        assets_config = {
            "input_partitions": Field(dict, is_required=False),
            "output_partitions": Field(dict, is_required=False),
        }
        op = _Op(
            name="__".join(out_asset_key.path),
            description=self.description,
            ins=input_defs,
            out=output_def,
            required_resource_keys=self.required_resource_keys,
            tags={"kind": self.compute_kind} if self.compute_kind else None,
            config_schema={"assets": assets_config},
        )(fn)

        input_names_by_asset_key = {
            in_def.asset_key: input_name for input_name, in_def in input_defs.items()
        }

        # Re-key the partition mappings from input name to upstream asset key.
        if self.partition_mappings:
            partition_mappings = {
                input_defs[input_name].asset_key: mapping
                for input_name, mapping in self.partition_mappings.items()
            }
        else:
            partition_mappings = None

        return AssetsDefinition(
            input_names_by_asset_key=input_names_by_asset_key,
            output_names_by_asset_key={out_asset_key: "result"},
            op=op,
            partitions_def=self.partitions_def,
            partition_mappings=partition_mappings,
        )
예제 #28
0
def test_assets_with_same_partitioning():
    """Check that identically partitioned assets map a key range to itself.

    Both assets share one static partitions definition, so the range
    ``"a"``..``"c"`` is expected to come back unchanged from both the
    upstream and the downstream lookup.
    """
    shared_partitions = StaticPartitionsDefinition(["a", "b", "c", "d"])

    @asset(partitions_def=shared_partitions)
    def upstream_asset():
        pass

    @asset(partitions_def=shared_partitions)
    def downstream_asset(upstream_asset):
        assert upstream_asset

    # Both helpers take the same arguments and should give the same answer.
    range_lookups = (
        get_upstream_partitions_for_partition_range,
        get_downstream_partitions_for_partition_range,
    )
    for lookup in range_lookups:
        mapped_range = lookup(
            downstream_asset,
            upstream_asset,
            AssetKey("upstream_asset"),
            PartitionKeyRange("a", "c"),
        )
        assert mapped_range == PartitionKeyRange("a", "c")
예제 #29
0
    def all_asset_tags(self):
        """Return the tags recorded for each asset key.

        When the ``ASSET_KEY_INDEX_COLS`` secondary index is available, tags
        are read from the ``tags`` column of the asset key table. Only rows
        with non-NULL tags are considered, and only those whose wipe
        timestamp is NULL or older than their last materialization timestamp.
        Otherwise tags are taken from the materialization deserialized from
        each row's ``last_materialization`` column.

        Returns:
            A ``defaultdict(dict)`` keyed by ``AssetKey``. Assets with empty
            or missing tags are not added.
        """
        tags_by_asset_key = defaultdict(dict)

        if not self.has_secondary_index(ASSET_KEY_INDEX_COLS):
            # Fallback: no indexed tag column, so read the tags off the
            # last materialization stored with each asset key row.
            fallback_query = db.select(
                [AssetKeyTable.c.asset_key, AssetKeyTable.c.last_materialization]
            )
            with self.index_connection() as conn:
                for key_str, event_json in conn.execute(fallback_query).fetchall():
                    event = self._asset_materialization_from_json_column(event_json)
                    if not (event and event.tags):
                        continue
                    # Copy the tags into a plain dict.
                    copied_tags = dict(event.tags.items())
                    tags_by_asset_key[AssetKey.from_db_string(key_str)] = copied_tags
            return tags_by_asset_key

        # Keep assets that were never wiped, or were materialized again after the wipe.
        not_wiped = db.or_(
            AssetKeyTable.c.wipe_timestamp.is_(None),
            AssetKeyTable.c.last_materialization_timestamp
            > AssetKeyTable.c.wipe_timestamp,
        )
        indexed_query = (
            db.select([AssetKeyTable.c.asset_key, AssetKeyTable.c.tags])
            .where(AssetKeyTable.c.tags.isnot(None))
            .where(not_wiped)
        )
        with self.index_connection() as conn:
            for key_str, tags_json in conn.execute(indexed_query).fetchall():
                decoded_tags = seven.json.loads(tags_json)
                if decoded_tags:
                    tags_by_asset_key[AssetKey.from_db_string(key_str)] = decoded_tags

        return tags_by_asset_key
예제 #30
0
def build_asset_outs(
    op_name: str,
    outs: Mapping[str, Out],
    ins: Mapping[str, In],
    internal_asset_deps: Mapping[str, Set[AssetKey]],
) -> Dict[str, Out]:
    """Resolve asset keys and dependency metadata for a multi-asset's outs.

    Args:
        op_name: Name of the multi-asset; used only in error messages.
        outs: Output definitions keyed by output name.
        ins: Input definitions keyed by input name.
        internal_asset_deps: For each output name, the set of asset keys
            that output depends on.

    Returns:
        A new dict with one entry per ``outs`` entry. Each ``Out`` is given a
        concrete asset key, and its metadata gains an
        ``ASSET_DEPENDENCY_METADATA_KEY`` entry when the output appears in
        ``internal_asset_deps``.

    Raises:
        Whatever ``check.invariant`` raises when ``internal_asset_deps``
        names an unknown output, or lists an asset key that is neither an
        input's key nor one of this multi-asset's output keys.
    """
    # Outs that lack an AssetKey instance get one derived from the out's name.
    asset_keys_by_out_name = {}
    for out_name, out in outs.items():
        if isinstance(out.asset_key, AssetKey):
            asset_keys_by_out_name[out_name] = out.asset_key
        else:
            asset_keys_by_out_name[out_name] = AssetKey([out_name])

    # Rebuild every out with its resolved key, recording inter-asset deps in metadata.
    rebuilt_outs: Dict[str, Out] = {}
    for out_name, out in outs.items():
        if out_name in internal_asset_deps:
            dependency_metadata = {
                ASSET_DEPENDENCY_METADATA_KEY: internal_asset_deps[out_name]
            }
        else:
            dependency_metadata = {}
        rebuilt_outs[out_name] = out._replace(
            asset_key=asset_keys_by_out_name[out_name],
            metadata=dict(**(out.metadata or {}), **dependency_metadata),
        )
    outs = rebuilt_outs

    # A dependency is legal when it is an input's key or one of our own output keys.
    valid_asset_deps = {in_def.asset_key for in_def in ins.values()}
    valid_asset_deps.update(asset_keys_by_out_name.values())

    for out_name, asset_keys in internal_asset_deps.items():
        check.invariant(
            out_name in outs,
            f"Invalid out key '{out_name}' supplied to `internal_asset_deps` argument for multi-asset "
            f"{op_name}. Must be one of the outs for this multi-asset {list(outs.keys())}.",
        )
        invalid_asset_deps = asset_keys.difference(valid_asset_deps)
        check.invariant(
            not invalid_asset_deps,
            f"Invalid asset dependencies: {invalid_asset_deps} specified in `internal_asset_deps` "
            f"argument for multi-asset '{op_name}' on key '{out_name}'. Each specified asset key "
            "must be associated with an input to the asset or produced by this asset. Valid "
            f"keys: {valid_asset_deps}",
        )

    return outs