コード例 #1
0
    def purge_ticks(self, origin_id, selector_id, tick_status, before):
        """Delete ticks in ``tick_status`` whose timestamp is strictly earlier than ``before``.

        Args:
            origin_id (str): Origin id of the instigator whose ticks should be purged.
            selector_id (Optional[str]): Selector id of the instigator. Used only when the
                storage has an instigators table. When ``None``, ticks are matched by
                ``origin_id`` alone.
            tick_status (TickStatus): Only ticks in this status are deleted.
            before (float): Unix timestamp; only ticks with ``timestamp < before`` are deleted.
        """
        check.str_param(origin_id, "origin_id")
        check.opt_str_param(selector_id, "selector_id")
        check.inst_param(tick_status, "tick_status", TickStatus)
        check.float_param(before, "before")

        utc_before = utc_datetime_from_timestamp(before)

        base_query = (
            JobTickTable.delete()  # pylint: disable=no-value-for-parameter
            .where(JobTickTable.c.status == tick_status.value)
            .where(JobTickTable.c.timestamp < utc_before)
        )

        if self.has_instigators_table() and selector_id is not None:
            # Match ticks by selector id, plus ticks that have no selector id recorded but
            # belong to this origin (presumably ticks written before selector ids existed).
            query = base_query.where(
                db.or_(
                    JobTickTable.c.selector_id == selector_id,
                    db.and_(
                        JobTickTable.c.selector_id.is_(None),
                        JobTickTable.c.job_origin_id == origin_id,
                    ),
                )
            )
        else:
            # No selector to match on: filter by origin only. Comparing against a None
            # selector_id above would compile to "selector_id IS NULL" and delete NULL-selector
            # ticks for every origin, not just this one.
            query = base_query.where(JobTickTable.c.job_origin_id == origin_id)

        with self.connect() as conn:
            conn.execute(query)
コード例 #2
0
 def __init__(self, timeout_length=1.0, sleep_length=0.1):
     """Initialise an empty, unblocked log queue.

     Args:
         timeout_length (float): Stored as ``self._timeout_length``; must be a float.
         sleep_length (float): Stored as ``self._sleep_length``; must be a float.
     """
     # gevent semaphore; presumably guards access to the log sequence below.
     self._log_queue_lock = gevent.lock.Semaphore()
     self._log_sequence = LogSequence()
     self._is_dequeueing_blocked = False
     # Wall-clock reference captured at construction time.
     self._queue_timeout = time.time()

     validated_timeout = check.float_param(timeout_length, 'timeout_length')
     self._timeout_length = validated_timeout

     validated_sleep = check.float_param(sleep_length, 'sleep_length')
     self._sleep_length = validated_sleep
コード例 #3
0
ファイル: events.py プロジェクト: danieldiamond/dagster
 def float(value, label, description=None):
     '''Build an ``EventMetadataEntry`` that carries a float value.

     Args:
         value (float): The value to attach; rejected by ``check.float_param`` if not a float.
         label: Label for the entry.
         description: Optional description for the entry.

     Returns:
         EventMetadataEntry: Entry wrapping ``FloatMetadataEntryData(value)``.
     '''
     checked_value = check.float_param(value, 'value')
     entry_data = FloatMetadataEntryData(checked_value)
     return EventMetadataEntry(label, description, entry_data)
コード例 #4
0
ファイル: pynotify.py プロジェクト: vitorarrais/dagster
def await_pg_notifications(
    conn_string: str,
    channels: Optional[List[str]] = None,
    timeout: float = 5.0,
    yield_on_timeout: bool = False,
    exit_event: Optional[Event] = None,
) -> Iterator[Optional[Notify]]:
    """Subscribe to PostgreSQL notifications, and handle them
    in infinite-loop style.

    The loop runs until ``exit_event`` is set (or forever if no event is given). The
    connection is always closed when the generator finishes, is closed, or raises —
    including when subscribing to the channels fails.

    Args:
        conn_string (str): connection string to PG DB
        channels (Optional[List[str]], optional): List of channel names to listen to. Defaults to None.
        timeout (float, optional): Timeout interval passed to ``select``. Negative values are
            clamped to 0. Defaults to 5.0.
        yield_on_timeout (bool, optional): Should the function yield on timeout. Defaults to False.
        exit_event (Optional[Event], optional): Event that indicates that polling for new notifications should stop. Defaults to None.

    Yields:
        Iterator[Optional[Notify]]: Can yield one of two types:
            1: None, in case of timeout
            2: Notify, in case of successful notification reception
    """

    check.str_param(conn_string, "conn_string")
    channels = None if channels is None else check.list_param(channels, "channels", of_type=str)
    check.float_param(timeout, "timeout")
    check.bool_param(yield_on_timeout, "yield_on_timeout")

    conn = get_conn(conn_string)

    try:
        # Subscribe inside the try so the connection is closed even if LISTEN fails.
        if channels:
            start_listening(conn, channels)

        while not (exit_event and exit_event.is_set()):
            try:
                r, w, x = select.select([conn], [], [], max(0, timeout))
                if not (r or w or x):
                    # select timed out with nothing ready.
                    if yield_on_timeout:
                        yield None

                if conn in r:
                    conn.poll()

                    # copy the conn.notifies list/queue & empty it
                    notify_list, conn.notifies = conn.notifies, []
                    for notif in notify_list:
                        yield notif

            except select.error as e:
                # An interrupted system call is retried on the next iteration; anything
                # else propagates.
                if e.errno != errno.EINTR:
                    raise
    finally:
        conn.close()
コード例 #5
0
    def execution_plan_step_success(self, step_key, millis):
        '''Log, via ``self.context.info``, that an execution plan step succeeded.

        Args:
            step_key (str): Key of the step that succeeded.
            millis (float): Elapsed time. It is interpolated into the message and also
                forwarded as the ``millis`` keyword.
        '''
        check.str_param(step_key, 'step_key')
        check.float_param(millis, 'millis')

        template = 'Execution of {step_key} succeeded in {millis}'
        message = template.format(step_key=step_key, millis=millis)

        self.context.info(
            message,
            event_type=EventType.EXECUTION_PLAN_STEP_SUCCESS.value,
            millis=millis,
            step_key=step_key,
        )
コード例 #6
0
    def purge_job_ticks(self, job_origin_id, tick_status, before):
        """Delete this job's ticks in ``tick_status`` with a timestamp earlier than ``before``.

        Args:
            job_origin_id (str): Origin id whose ticks are purged.
            tick_status (JobTickStatus): Only ticks in this status are deleted.
            before (float): Unix timestamp; only ticks with ``timestamp < before`` are deleted.
        """
        check.str_param(job_origin_id, "job_origin_id")
        check.inst_param(tick_status, "tick_status", JobTickStatus)
        check.float_param(before, "before")

        cutoff = utc_datetime_from_timestamp(before)

        delete_stmt = (
            JobTickTable.delete()  # pylint: disable=no-value-for-parameter
            .where(JobTickTable.c.status == tick_status.value)
            .where(JobTickTable.c.timestamp < cutoff)
            .where(JobTickTable.c.job_origin_id == job_origin_id)
        )

        with self.connect() as conn:
            conn.execute(delete_stmt)
コード例 #7
0
ファイル: dag_gen.py プロジェクト: trevenrawr/dagster
def generate_pipeline(name, size, connect_factor=1.0):
    """Generate a pseudo-random pipeline whose dependency graph is acyclic.

    The module-level ``random`` generator is re-seeded with ``name``, so identical arguments
    produce an identical pipeline (this also resets the global random state).

    Args:
        name (str): Pipeline name. It is also the random seed and the prefix of every solid id.
        size (int): Number of solids; must be greater than 3.
        connect_factor (float): ``int(size * connect_factor)`` dependency edges are drawn.
            Edges that target an already-wired input overwrite the earlier edge.

    Returns:
        PipelineDefinition: The generated pipeline.
    """
    check.int_param(size, "size")
    # ``size > 3`` means the smallest accepted pipeline has 4 solids; keep the message in sync.
    check.invariant(size > 3, "Can not create pipelines with fewer than 4 nodes")
    check.float_param(connect_factor, "connect_factor")

    random.seed(name)

    # generate nodes
    solids = {}
    for i in range(size):
        num_inputs = random.randint(1, 3)
        num_outputs = random.randint(1, 3)
        num_cfg = random.randint(0, 5)
        solid_id = "{}_solid_{}".format(name, i)
        solids[solid_id] = generate_solid(
            solid_id=solid_id,
            num_inputs=num_inputs,
            num_outputs=num_outputs,
            num_cfg=num_cfg,
        )

    solid_ids = list(solids.keys())
    # connections: an edge always goes from a lower-index solid to a higher-index one,
    # which guarantees the resulting graph is acyclic.
    deps = defaultdict(dict)
    for _ in range(int(size * connect_factor)):
        # choose output (never the last solid, so a later solid always exists)
        out_idx = random.randint(0, len(solid_ids) - 2)
        out_solid_id = solid_ids[out_idx]
        output_solid = solids[out_solid_id]
        output_name = output_solid.output_defs[random.randint(
            0,
            len(output_solid.output_defs) - 1)].name

        # choose input on a strictly later solid
        in_idx = random.randint(out_idx + 1, len(solid_ids) - 1)
        in_solid_id = solid_ids[in_idx]
        input_solid = solids[in_solid_id]
        input_name = input_solid.input_defs[random.randint(
            0,
            len(input_solid.input_defs) - 1)].name

        # map
        deps[in_solid_id][input_name] = DependencyDefinition(
            out_solid_id, output_name)

    return PipelineDefinition(name=name,
                              solid_defs=list(solids.values()),
                              dependencies=deps)
コード例 #8
0
ファイル: scheduler.py プロジェクト: boltsource/dagster
    def __new__(cls, schedule_name, cron_schedule, timestamp, status, run_id=None, error=None):
        '''
        Serializable record of a single schedule tick, as stored in ``ScheduleStorage``. The
        storage implementation is responsible for assigning tick ids, so all remaining data
        lives in this class and can be stored independently of the id.

        Arguments:
            schedule_name (str): The name of the schedule for this tick
            cron_schedule (str): The cron schedule of the ``ScheduleDefinition``, kept for
                tracking purposes (useful when debugging changes in the cron schedule).
            timestamp (float): The timestamp at which this schedule execution started
            status (ScheduleTickStatus): The status of the tick, which can be updated

        Keyword Arguments:
            run_id (str): The run created by the tick. This is set only when the status is
                ``ScheduleTickStatus.SUCCESS``
            error (SerializableErrorInfo): The error caught during schedule execution. This is
                set only when the status is ``ScheduleTickStatus.FAILURE``
        '''
        # status/run_id/error consistency is checked by the helper before anything else.
        _validate_schedule_tick_args(status, run_id, error)

        checked_name = check.str_param(schedule_name, 'schedule_name')
        checked_cron = check.str_param(cron_schedule, 'cron_schedule')
        checked_timestamp = check.float_param(timestamp, 'timestamp')

        return super(ScheduleTickData, cls).__new__(
            cls, checked_name, checked_cron, checked_timestamp, status, run_id, error
        )
コード例 #9
0
 def __new__(
     cls,
     backfill_id,
     partition_set_origin,
     status,
     partition_names,
     from_failure,
     reexecution_steps,
     tags,
     backfill_timestamp,
     last_submitted_partition_name=None,
     error=None,
 ):
     """Validate every field and build a ``PartitionBackfill`` record.

     Args:
         backfill_id (str): Identifier of the backfill.
         partition_set_origin (ExternalPartitionSetOrigin): Origin of the partition set.
         status (BulkActionStatus): Current status of the backfill.
         partition_names (List[str]): Partitions targeted by the backfill.
         from_failure (bool): Flag stored on the record.
         reexecution_steps (Optional[List[str]]): Optional list of step keys.
         tags (Optional[Dict[str, str]]): Optional string-to-string tags.
         backfill_timestamp (float): Timestamp of the backfill.
         last_submitted_partition_name (Optional[str]): Last partition submitted, if any.
         error (Optional[SerializableErrorInfo]): Error details, if any.
     """
     # Fields are validated left to right, in declaration order.
     validated_fields = (
         check.str_param(backfill_id, "backfill_id"),
         check.inst_param(
             partition_set_origin, "partition_set_origin", ExternalPartitionSetOrigin
         ),
         check.inst_param(status, "status", BulkActionStatus),
         check.list_param(partition_names, "partition_names", of_type=str),
         check.bool_param(from_failure, "from_failure"),
         check.opt_list_param(reexecution_steps, "reexecution_steps", of_type=str),
         check.opt_dict_param(tags, "tags", key_type=str, value_type=str),
         check.float_param(backfill_timestamp, "backfill_timestamp"),
         check.opt_str_param(
             last_submitted_partition_name, "last_submitted_partition_name"
         ),
         check.opt_inst_param(error, "error", SerializableErrorInfo),
     )
     return super(PartitionBackfill, cls).__new__(cls, *validated_fields)
コード例 #10
0
ファイル: log.py プロジェクト: yingjiebyron/dagster
 def __new__(
     cls,
     error_info,
     message,
     level,
     user_message,
     run_id,
     timestamp,
     step_key=None,
     pipeline_name=None,
     dagster_event=None,
 ):
     """Validate the fields of a log event and build an ``EventRecord``.

     Args:
         error_info (Optional[SerializableErrorInfo]): Error details, if any.
         message (str): Full log message.
         level: Log level, normalised by ``coerce_valid_log_level``.
         user_message (str): User message.
         run_id (str): Run the event belongs to.
         timestamp (float): Time of the event.
         step_key (Optional[str]): Step the event relates to, if any.
         pipeline_name (Optional[str]): Pipeline name, if any.
         dagster_event (Optional[DagsterEvent]): Structured event payload, if any.
     """
     record_fields = (
         check.opt_inst_param(error_info, "error_info", SerializableErrorInfo),
         check.str_param(message, "message"),
         coerce_valid_log_level(level),
         check.str_param(user_message, "user_message"),
         check.str_param(run_id, "run_id"),
         check.float_param(timestamp, "timestamp"),
         check.opt_str_param(step_key, "step_key"),
         check.opt_str_param(pipeline_name, "pipeline_name"),
         check.opt_inst_param(dagster_event, "dagster_event", DagsterEvent),
     )
     return super(EventRecord, cls).__new__(cls, *record_fields)
コード例 #11
0
ファイル: runs.py プロジェクト: tristaneljed/dagster
 def __new__(cls, run_id, timestamp, pipeline_name):
     '''Build a ``DagsterRunMeta`` from a run id, a float timestamp and a pipeline name.'''
     checked_run_id = check.str_param(run_id, 'run_id')
     checked_timestamp = check.float_param(timestamp, 'timestamp')
     checked_pipeline_name = check.str_param(pipeline_name, 'pipeline_name')
     return super(DagsterRunMeta, cls).__new__(
         cls, checked_run_id, checked_timestamp, checked_pipeline_name
     )
コード例 #12
0
    def __init__(
        self,
        error_info,
        message,
        level,
        user_message,
        run_id,
        timestamp,
        step_key=None,
        pipeline_name=None,
        dagster_event=None,
    ):
        '''Validate and store the fields of a log event record.

        Args:
            error_info (Optional[SerializableErrorInfo]): Error details, if any.
            message (str): Full log message.
            level: Log level, validated by ``check_valid_level_param``.
            user_message (str): User message.
            run_id (str): Run the event belongs to.
            timestamp (float): Time of the event.
            step_key (Optional[str]): Step the event relates to, if any.
            pipeline_name (Optional[str]): Pipeline name, if any.
            dagster_event (Optional[DagsterEvent]): Structured event payload, if any.
        '''
        # Imported locally, presumably to avoid a circular import at module load time.
        from dagster.core.events import DagsterEvent

        self._error_info = check.opt_inst_param(
            error_info, 'error_info', SerializableErrorInfo
        )
        self._message = check.str_param(message, 'message')
        self._level = check_valid_level_param(level)
        self._user_message = check.str_param(user_message, 'user_message')
        self._run_id = check.str_param(run_id, 'run_id')
        self._timestamp = check.float_param(timestamp, 'timestamp')
        self._step_key = check.opt_str_param(step_key, 'step_key')
        self._pipeline_name = check.opt_str_param(pipeline_name, 'pipeline_name')
        self._dagster_event = check.opt_inst_param(
            dagster_event, 'dagster_event', DagsterEvent
        )
コード例 #13
0
ファイル: log.py プロジェクト: trevenrawr/dagster
    def __new__(
        cls,
        error_info,
        level,
        user_message,
        run_id,
        timestamp,
        step_key=None,
        pipeline_name=None,
        dagster_event=None,
        job_name=None,
    ):
        """Build an ``EventLogEntry``, accepting either ``pipeline_name`` or ``job_name``.

        ``job_name`` is an alias: whichever of the two is truthy is stored as the pipeline name.

        Raises:
            DagsterInvariantViolationError: If both ``pipeline_name`` and ``job_name`` are truthy.
        """
        if bool(pipeline_name and job_name):
            raise DagsterInvariantViolationError(
                "Provided both `pipeline_name` and `job_name` parameters to `EventLogEntry` "
                "initialization. Please provide only one or the other.")

        resolved_name = pipeline_name if pipeline_name else job_name

        checked_error_info = check.opt_inst_param(
            error_info, "error_info", SerializableErrorInfo
        )
        checked_level = coerce_valid_log_level(level)
        checked_user_message = check.str_param(user_message, "user_message")
        checked_run_id = check.str_param(run_id, "run_id")
        checked_timestamp = check.float_param(timestamp, "timestamp")
        checked_step_key = check.opt_str_param(step_key, "step_key")
        checked_pipeline_name = check.opt_str_param(resolved_name, "pipeline_name")
        checked_event = check.opt_inst_param(dagster_event, "dagster_event", DagsterEvent)

        return super(EventLogEntry, cls).__new__(
            cls,
            checked_error_info,
            checked_level,
            checked_user_message,
            checked_run_id,
            checked_timestamp,
            checked_step_key,
            checked_pipeline_name,
            checked_event,
        )
コード例 #14
0
 def __new__(
     cls,
     error_info,
     message,
     level,
     user_message,
     run_id,
     timestamp,
     step_key=None,
     pipeline_name=None,
     dagster_event=None,
 ):
     '''Validate the fields of a log event and build an ``EventRecord``.

     Args:
         error_info (Optional[SerializableErrorInfo]): Error details, if any.
         message (str): Full log message.
         level: Log level, normalised by ``coerce_valid_log_level``.
         user_message (str): User message.
         run_id (str): Run the event belongs to.
         timestamp (float): Time of the event.
         step_key (Optional[str]): Step the event relates to, if any.
         pipeline_name (Optional[str]): Pipeline name, if any.
         dagster_event (Optional[DagsterEvent]): Structured event payload, if any.
     '''
     checked_error_info = check.opt_inst_param(error_info, 'error_info', SerializableErrorInfo)
     checked_message = check.str_param(message, 'message')
     checked_level = coerce_valid_log_level(level)
     checked_user_message = check.str_param(user_message, 'user_message')
     checked_run_id = check.str_param(run_id, 'run_id')
     checked_timestamp = check.float_param(timestamp, 'timestamp')
     checked_step_key = check.opt_str_param(step_key, 'step_key')
     checked_pipeline_name = check.opt_str_param(pipeline_name, 'pipeline_name')
     checked_event = check.opt_inst_param(dagster_event, 'dagster_event', DagsterEvent)

     return super(EventRecord, cls).__new__(
         cls,
         checked_error_info,
         checked_message,
         checked_level,
         checked_user_message,
         checked_run_id,
         checked_timestamp,
         checked_step_key,
         checked_pipeline_name,
         checked_event,
     )
コード例 #15
0
ファイル: events.py プロジェクト: wslulciuc/dagster
 def __init__(self, error_info, message, level, user_message, event_type, run_id, timestamp):
     '''Validate and store the fields of a log event.

     Args:
         error_info (Optional[SerializableErrorInfo]): Error details, if any.
         message (str): Full log message.
         level: Log level, validated by ``check_valid_level_param``.
         user_message (str): User message.
         event_type (EventType): Kind of event.
         run_id (str): Run the event belongs to.
         timestamp (float): Time of the event.
     '''
     checked_error_info = check.opt_inst_param(error_info, 'error_info', SerializableErrorInfo)
     checked_message = check.str_param(message, 'message')
     checked_level = check_valid_level_param(level)
     checked_user_message = check.str_param(user_message, 'user_message')
     checked_event_type = check.inst_param(event_type, 'event_type', EventType)
     checked_run_id = check.str_param(run_id, 'run_id')
     checked_timestamp = check.float_param(timestamp, 'timestamp')

     self._error_info = checked_error_info
     self._message = checked_message
     self._level = checked_level
     self._user_message = checked_user_message
     self._event_type = checked_event_type
     self._run_id = checked_run_id
     self._timestamp = checked_timestamp
コード例 #16
0
 def __new__(cls, timestamp, daemon_type, daemon_id, error):
     """Build a validated ``DaemonHeartbeat``.

     Args:
         timestamp (float): Heartbeat timestamp.
         daemon_type (DaemonType): Type of the daemon that emitted the heartbeat.
         daemon_id (Optional[str]): Identifier of the daemon instance.
         error (Optional[SerializableErrorInfo]): Error reported with the heartbeat, if any.
     """
     return super(DaemonHeartbeat, cls).__new__(
         cls,
         timestamp=check.float_param(timestamp, "timestamp"),
         daemon_type=check.inst_param(daemon_type, "daemon_type",
                                      DaemonType),
         # Validate daemon_id like every other field instead of passing it through unchecked.
         daemon_id=check.opt_str_param(daemon_id, "daemon_id"),
         error=check.opt_inst_param(error, "error", SerializableErrorInfo),
     )
コード例 #17
0
ファイル: instigation.py プロジェクト: keyz/dagster
    def __new__(
        cls,
        job_origin_id,
        job_name,
        job_type,
        status,
        timestamp,
        run_ids=None,
        run_keys=None,
        error=None,
        skip_reason=None,
        cursor=None,
        origin_run_ids=None,
        failure_count=None,
    ):
        """
        Serializable data for a single tick, as stored in ``JobStorage``. The storage
        implementation assigns tick ids, so all remaining data lives in this class and can be
        stored independently of the id.

        Arguments:
            job_origin_id (str): The id of the job target for this tick
            job_name (str): The name of the job for this tick
            job_type (InstigatorType): The type of this job for this tick
            status (TickStatus): The status of the tick, which can be updated
            timestamp (float): The timestamp at which this job evaluation started

        Keyword Arguments:
            run_ids (List[str]): The runs created by the tick.
            run_keys (List[str]): Run keys associated with the tick.
            error (SerializableErrorInfo): The error caught during job execution. This is set
                only when the status is ``TickStatus.Failure``
            skip_reason (str): message for why the tick was skipped
            cursor (str): Optional cursor string stored with the tick.
            origin_run_ids (List[str]): The runs originating the job.
            failure_count (int): The number of times this tick has failed. If the status is not
                FAILED, this is the number of previous failures before it reached the current
                state. Defaults to 0 when not provided.
        """
        # error and skip_reason are validated by this helper rather than by check.* below.
        _validate_job_tick_args(job_type, status, run_ids, error, skip_reason)

        checked_origin_id = check.str_param(job_origin_id, "job_origin_id")
        checked_name = check.str_param(job_name, "job_name")
        checked_type = check.inst_param(job_type, "job_type", InstigatorType)
        checked_status = check.inst_param(status, "status", TickStatus)
        checked_timestamp = check.float_param(timestamp, "timestamp")
        checked_run_ids = check.opt_list_param(run_ids, "run_ids", of_type=str)
        checked_run_keys = check.opt_list_param(run_keys, "run_keys", of_type=str)
        checked_cursor = check.opt_str_param(cursor, "cursor")
        checked_origin_run_ids = check.opt_list_param(
            origin_run_ids, "origin_run_ids", of_type=str
        )
        checked_failure_count = check.opt_int_param(failure_count, "failure_count", 0)

        return super(TickData, cls).__new__(
            cls,
            checked_origin_id,
            checked_name,
            checked_type,
            checked_status,
            checked_timestamp,
            checked_run_ids,
            checked_run_keys,
            error,
            skip_reason,
            cursor=checked_cursor,
            origin_run_ids=checked_origin_run_ids,
            failure_count=checked_failure_count,
        )
コード例 #18
0
    def __new__(cls, timestamp, daemon_type, daemon_id, errors=None):
        """Build a validated ``DaemonHeartbeat``.

        Args:
            timestamp (float): Heartbeat timestamp.
            daemon_type (str): Type of the daemon that emitted the heartbeat.
            daemon_id (Optional[str]): Identifier of the daemon instance.
            errors (Optional[List[SerializableErrorInfo]]): Errors reported with the heartbeat.
        """
        errors = check.opt_list_param(errors, "errors", of_type=SerializableErrorInfo)

        return super(DaemonHeartbeat, cls).__new__(
            cls,
            timestamp=check.float_param(timestamp, "timestamp"),
            daemon_type=check.str_param(daemon_type, "daemon_type"),
            # Previously passed through unchecked; validate like the other fields.
            daemon_id=check.opt_str_param(daemon_id, "daemon_id"),
            errors=errors,
        )
コード例 #19
0
 def __new__(cls, schedule_type, start, timezone, fmt, end_offset):
     """Build validated external time-window partitions definition data.

     Args:
         schedule_type (ScheduleType): Schedule type of the partitions.
         start (float): Start value, validated as a float.
         timezone (Optional[str]): Timezone name, if any.
         fmt (str): Format string.
         end_offset (int): End offset.
     """
     checked_schedule_type = check.inst_param(schedule_type, "schedule_type", ScheduleType)
     checked_start = check.float_param(start, "start")
     checked_timezone = check.opt_str_param(timezone, "timezone")
     checked_fmt = check.str_param(fmt, "fmt")
     checked_end_offset = check.int_param(end_offset, "end_offset")

     return super(ExternalTimeWindowPartitionsDefinitionData, cls).__new__(
         cls,
         schedule_type=checked_schedule_type,
         start=checked_start,
         timezone=checked_timezone,
         fmt=checked_fmt,
         end_offset=checked_end_offset,
     )
コード例 #20
0
 def __new__(cls, process_or_error, loadable_target_origin,
             creation_timestamp, server_id):
     """Build a validated ``ProcessRegistryEntry``.

     Args:
         process_or_error (GrpcServerProcess | SerializableErrorInfo): The server process, or
             the error raised while starting it.
         loadable_target_origin (LoadableTargetOrigin): Origin the server loads.
         creation_timestamp (float): When the entry was created.
         server_id (Optional[str]): Server identifier, if any.
     """
     allowed_process_types = (GrpcServerProcess, SerializableErrorInfo)
     fields = (
         check.inst_param(process_or_error, "process_or_error", allowed_process_types),
         check.inst_param(
             loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin
         ),
         check.float_param(creation_timestamp, "creation_timestamp"),
         check.opt_str_param(server_id, "server_id"),
     )
     return super(ProcessRegistryEntry, cls).__new__(cls, *fields)
コード例 #21
0
 def __new__(
     cls, result: DbtResult, state: str, start: str, end: str, elapsed: float,
 ):
     """Build the record from a ``DbtResult`` plus validated summary fields.

     Args:
         result (DbtResult): Stored as-is; not validated here.
         state (str): State string.
         start (str): Start value, as a string.
         end (str): End value, as a string.
         elapsed (float): Elapsed time.
     """
     checked_state = check.str_param(state, "state")
     checked_start = check.str_param(start, "start")
     checked_end = check.str_param(end, "end")
     checked_elapsed = check.float_param(elapsed, "elapsed")
     return super().__new__(
         cls, result, checked_state, checked_start, checked_end, checked_elapsed
     )
コード例 #22
0
    def execution_time_iterator(self, start_timestamp):
        """Yield, without end, the datetimes that match ``self.cron_schedule``, starting
        from the first one that is >= ``start_timestamp``.

        Datetimes are expressed in ``self.execution_timezone``, or in the current local
        timezone (as reported by pendulum) when no execution timezone is set.

        Args:
            start_timestamp (float): Unix timestamp to start iterating from.
        """
        check.float_param(start_timestamp, "start_timestamp")

        tz_name = self.execution_timezone or pendulum.now().timezone.name
        start = pendulum.from_timestamp(start_timestamp, tz=tz_name)

        cron_iter = croniter(self.cron_schedule, start)
        # Step back once so that the first get_next() returns the earliest scheduled time
        # that is >= start.
        cron_iter.get_prev(datetime.datetime)

        while True:
            candidate = pendulum.instance(cron_iter.get_next(datetime.datetime)).in_tz(tz_name)

            # Around DST transitions croniter can produce times that don't actually match
            # the cron schedule; skip those.
            if not croniter.match(self.cron_schedule, candidate):
                continue
            yield candidate
コード例 #23
0
ファイル: executor.py プロジェクト: sarahmk125/dagster
    def __init__(
        self,
        retries,
        broker=None,
        backend=None,
        include=None,
        config_source=None,
        job_config=None,
        job_namespace=None,
        load_incluster_config=False,
        kubeconfig_file=None,
        repo_location_name=None,
        job_wait_timeout=None,
    ):
        """Validate and store the executor configuration.

        NOTE(review): ``job_config``, ``repo_location_name`` and ``job_wait_timeout`` default to
        ``None`` but are validated with non-optional checks. ``check.float_param`` rejects
        ``None``, and ``inst_param``/``str_param`` presumably do too, so callers must supply
        them. Confirm whether these defaults are intentional.

        Args:
            retries (RetryMode): Retry mode.
            broker (Optional[str]): Broker URL; defaults to ``broker_url``.
            backend (Optional[str]): Result backend; defaults to ``result_backend``.
            include (Optional[List[str]]): Extra entries stored as ``self.include``.
            config_source (Optional[dict]): Entries layered over ``DEFAULT_CONFIG``.
            job_config (DagsterK8sJobConfig): Job configuration.
            job_namespace (Optional[str]): Namespace; defaults to ``"default"``.
            load_incluster_config (bool): Use in-cluster config; mutually exclusive with
                ``kubeconfig_file``.
            kubeconfig_file (Optional[str]): Path to a kubeconfig file.
            repo_location_name (str): Repository location name.
            job_wait_timeout (float): Job wait timeout.
        """
        if not load_incluster_config:
            check.opt_str_param(kubeconfig_file, "kubeconfig_file")
        else:
            # An explicit kubeconfig file makes no sense with in-cluster config.
            check.invariant(
                kubeconfig_file is None,
                "`kubeconfig_file` is set but `load_incluster_config` is True.",
            )

        self._retries = check.inst_param(retries, "retries", RetryMode)
        self.broker = check.opt_str_param(broker, "broker", default=broker_url)
        self.backend = check.opt_str_param(backend, "backend", default=result_backend)
        self.include = check.opt_list_param(include, "include", of_type=str)

        # User-supplied config_source entries override the DEFAULT_CONFIG ones.
        user_config_source = check.opt_dict_param(config_source, "config_source")
        self.config_source = dict_wrapper(dict(DEFAULT_CONFIG, **user_config_source))

        self.job_config = check.inst_param(job_config, "job_config", DagsterK8sJobConfig)
        self.job_namespace = check.opt_str_param(
            job_namespace, "job_namespace", default="default"
        )

        self.load_incluster_config = check.bool_param(
            load_incluster_config, "load_incluster_config"
        )

        self.kubeconfig_file = check.opt_str_param(kubeconfig_file, "kubeconfig_file")
        self.repo_location_name = check.str_param(repo_location_name, "repo_location_name")
        self.job_wait_timeout = check.float_param(job_wait_timeout, "job_wait_timeout")
コード例 #24
0
    def __new__(
        cls,
        job_origin_id,
        job_name,
        job_type,
        status,
        timestamp,
        run_ids=None,
        run_keys=None,
        error=None,
        skip_reason=None,
    ):
        """
        Serializable data for a single job tick, as stored in ``JobStorage``. The storage
        implementation assigns tick ids, so all remaining data lives in this class and can be
        stored independently of the id.

        Arguments:
            job_origin_id (str): The id of the job target for this tick
            job_name (str): The name of the job for this tick
            job_type (JobType): The type of this job for this tick
            status (JobTickStatus): The status of the tick, which can be updated
            timestamp (float): The timestamp at which this job evaluation started

        Keyword Arguments:
            run_ids (List[str]): The runs created by the tick.
            run_keys (List[str]): Run keys associated with the tick.
            error (SerializableErrorInfo): The error caught during job execution. This is set
                only when the status is ``JobTickStatus.Failure``
            skip_reason (str): message for why the tick was skipped
        """
        # error and skip_reason are validated by this helper rather than by check.* below.
        _validate_job_tick_args(job_type, status, run_ids, error, skip_reason)

        fields = (
            check.str_param(job_origin_id, "job_origin_id"),
            check.str_param(job_name, "job_name"),
            check.inst_param(job_type, "job_type", JobType),
            check.inst_param(status, "status", JobTickStatus),
            check.float_param(timestamp, "timestamp"),
            check.opt_list_param(run_ids, "run_ids", of_type=str),
            check.opt_list_param(run_keys, "run_keys", of_type=str),
            error,
            skip_reason,
        )
        return super(JobTickData, cls).__new__(cls, *fields)
コード例 #25
0
ファイル: types.py プロジェクト: trevenrawr/dagster
    def __new__(
        cls,
        timestamp: float,
        daemon_type: str,
        daemon_id: str,
        errors: Optional[List[SerializableErrorInfo]] = None,
    ):
        """Build a validated ``DaemonHeartbeat``.

        Args:
            timestamp (float): Heartbeat timestamp.
            daemon_type (str): Type of the daemon that emitted the heartbeat.
            daemon_id (str): Identifier of the daemon instance; ``None`` is also accepted.
            errors (Optional[List[SerializableErrorInfo]]): Errors reported with the heartbeat.
        """
        checked_errors = check.opt_list_param(errors, "errors", of_type=SerializableErrorInfo)
        checked_timestamp = check.float_param(timestamp, "timestamp")
        checked_type = check.str_param(daemon_type, "daemon_type")
        checked_id = check.opt_str_param(daemon_id, "daemon_id")

        return super(DaemonHeartbeat, cls).__new__(
            cls,
            timestamp=checked_timestamp,
            daemon_type=checked_type,
            daemon_id=checked_id,
            errors=checked_errors,
        )
コード例 #26
0
 def __new__(
     cls,
     instigator_origin_id: str,
     instigator_name: str,
     instigator_type: InstigatorType,
     status: TickStatus,
     timestamp: float,
     run_ids: Optional[List[str]] = None,
     run_keys: Optional[List[str]] = None,
     error: Optional[SerializableErrorInfo] = None,
     skip_reason: Optional[str] = None,
     cursor: Optional[str] = None,
     origin_run_ids: Optional[List[str]] = None,
     failure_count: Optional[int] = None,
     selector_id: Optional[str] = None,
 ):
     """Validate every field and build a ``TickData`` record.

     ``error`` and ``skip_reason`` are validated by ``_validate_tick_args``. All other fields
     go through ``check.*``. ``failure_count`` falls back to 0 when not provided.
     """
     _validate_tick_args(instigator_type, status, run_ids, error, skip_reason)

     positional_fields = (
         check.str_param(instigator_origin_id, "instigator_origin_id"),
         check.str_param(instigator_name, "instigator_name"),
         check.inst_param(instigator_type, "instigator_type", InstigatorType),
         check.inst_param(status, "status", TickStatus),
         check.float_param(timestamp, "timestamp"),
         check.opt_list_param(run_ids, "run_ids", of_type=str),
         check.opt_list_param(run_keys, "run_keys", of_type=str),
         error,
         skip_reason,
     )
     keyword_fields = dict(
         cursor=check.opt_str_param(cursor, "cursor"),
         origin_run_ids=check.opt_list_param(origin_run_ids, "origin_run_ids", of_type=str),
         failure_count=check.opt_int_param(failure_count, "failure_count", 0),
         selector_id=check.opt_str_param(selector_id, "selector_id"),
     )
     return super(TickData, cls).__new__(cls, *positional_fields, **keyword_fields)
コード例 #27
0
ファイル: __init__.py プロジェクト: xaniasd/dagster
 def __init__(self, timestamp):
     """Initialise the future tick with a timestamp validated as a float."""
     checked_timestamp = check.float_param(timestamp, "timestamp")
     super(DauphinScheduleFutureTick, self).__init__(timestamp=checked_timestamp)
コード例 #28
0
ファイル: events.py プロジェクト: danieldiamond/dagster
 def __new__(cls, value):
     '''Wrap a float metadata value.

     Args:
         value (float): The value to store; validated once with ``check.float_param``.
     '''
     # The value was previously validated twice; a single check is sufficient.
     return super(FloatMetadataEntryData, cls).__new__(cls, check.float_param(value, 'value'))
コード例 #29
0
ファイル: test_check.py プロジェクト: shcheklein/dagster
def test_float_param():
    '''check.float_param returns floats unchanged and rejects None, strings and ints.'''
    for accepted in (-1.0, 0.0, 1.1):
        returned = check.float_param(accepted, 'param_name')
        assert returned == accepted

    for rejected in (None, 's', 1, 0):
        with pytest.raises(ParameterCheckError):
            check.float_param(rejected, 'param_name')
コード例 #30
0
ファイル: test_check.py プロジェクト: iamahern/dagster
def test_float_param():
    """Floats pass through ``check.float_param``; None, strings and ints raise."""
    valid_values = [-1.0, 0.0, 1.1]
    for valid in valid_values:
        assert check.float_param(valid, "param_name") == valid

    # Ints are rejected even when numerically equal to an accepted float (0 vs 0.0).
    invalid_values = [None, "s", 1, 0]
    for invalid in invalid_values:
        with pytest.raises(ParameterCheckError):
            check.float_param(invalid, "param_name")