    def _refresh_python_profile_stats(self, refresh_stats):
        """Helper function to load the most recent Python stats via the Python stats reader."""
        if refresh_stats:
            get_logger().info("Refreshing python profile stats.")
            self.python_profile_stats = self.python_stats_reader.load_python_profile_stats()
    def _get_time_interval_for_step(self, start_step, end_step):
        """
        Use the Python timeline files to get the time interval for a step interval.
        """
        event_list = self.get_events(
            0,
            time.time() * CONVERT_TO_MICROSECS,
            TimeUnits.MICROSECONDS,
            file_suffix_filter=[PYTHONTIMELINE_SUFFIX],
        )

        start_time_us = end_time_us = None
        event_list.sort(key=lambda x: x.start_time)
        for event in event_list:
            if (hasattr(event, "event_args") and event.event_args is not None
                    and "step_num" in event.event_args):
                # get the start time of start step
                if start_time_us is None and start_step == int(
                        event.event_args["step_num"]):
                    start_time_us = event.start_time
                # get the time just before the start of end_step
                if end_time_us is None and (end_step == int(
                        event.event_args["step_num"])):
                    end_time_us = event.start_time - 1
            if start_time_us is not None and end_time_us is not None:
                break

        if start_time_us is None or end_time_us is None:
            get_logger().info(
                f"Invalid step interval [{start_step}, {end_step}]")
            start_time_us = end_time_us = 0
        return start_time_us, end_time_us
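Below is a hedged usage sketch of the method above; `analysis` is a placeholder for an instance of the class that defines it, not a name from the original code.

# Hypothetical call: map a step interval to a wall-clock interval in microseconds.
start_us, end_us = analysis._get_time_interval_for_step(start_step=10, end_step=12)
if (start_us, end_us) == (0, 0):
    print("Invalid step interval: no matching timeline events were found.")
else:
    print(f"Steps 10-11 ran between {start_us} and {end_us} microseconds since epoch.")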
Example #3
def get_json_config_as_dict(json_config_path) -> Dict:
    """Checks json_config_path, then environment variables, then attempts to load.

    Will throw FileNotFoundError if a config is not available.
    """
    if json_config_path is not None:
        path = json_config_path
    else:
        path = os.getenv(CONFIG_FILE_PATH_ENV_STR, DEFAULT_CONFIG_FILE_PATH)
    with open(path) as json_config_file:
        params_dict = json.load(json_config_file)
    get_logger().info(f"Creating hook from json_config at {path}.")
    return params_dict
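A hedged usage sketch for the loader above; the config key shown is illustrative rather than the real smdebug schema.

import json
import tempfile

# Write a throwaway JSON config (the key is illustrative only).
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"LocalPath": "/tmp/smdebug_out"}, f)
    config_path = f.name

# An explicit path takes precedence over the CONFIG_FILE_PATH_ENV_STR environment variable.
params = get_json_config_as_dict(json_config_path=config_path)
assert params["LocalPath"] == "/tmp/smdebug_out"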
    def _dump_stats(self, stats_dir):
        """Dump the stats as a JSON dictionary to a file `python_stats.json` in the provided stats directory.
        """
        stats_file_path = os.path.join(stats_dir, PYINSTRUMENT_JSON_FILENAME)
        html_file_path = os.path.join(stats_dir, PYINSTRUMENT_HTML_FILENAME)
        try:
            session = self._profiler.last_session
            json_stats = JSONRenderer().render(session)
            get_logger().info(
                f"JSON stats collected for pyinstrument: {json_stats}.")
            with open(stats_file_path, "w") as json_data:
                json_data.write(json_stats)
            get_logger().info(
                f"Dumping pyinstrument stats to {stats_file_path}.")

            with open(html_file_path, "w") as html_data:
                html_data.write(self._profiler.output_html())
            get_logger().info(
                f"Dumping pyinstrument output html to {html_file_path}.")
        except (UnboundLocalError, AssertionError):
            # Handles error that sporadically occurs within pyinstrument.
            get_logger().info(
                f"The pyinstrument profiling session has been corrupted for: {stats_file_path}."
            )
            with open(stats_file_path, "w") as json_data:
                json.dump({"root_frame": None}, json_data)

            with open(html_file_path, "w") as html_data:
                html_data.write("An error occurred during profiling!")
Example #5
    def get_device_usage_stats(self, device=None, utilization_ranges=None):
        """
        Find the usage spread based on utilization ranges. If ranges are not provided,
        >90, 10-90, and <10 are considered.
        :param device: Resource or list of Resources (e.g. Resource.CPU, Resource.GPU)
        :param utilization_ranges: list of (start, end) tuples
        """
        if (device is not None) and (not isinstance(device, (list, Resource))):
            get_logger().info(f"{device} should be of type list or Resource")
            return pd.DataFrame()

        if device is None:
            resources = [Resource.CPU.value, Resource.GPU.value]
        else:
            if isinstance(device, Resource):
                device = [device]
            resources = [x.value for x in device]

        if utilization_ranges is None:
            utilization_ranges = [(90, 100), (10, 90), (0, 10)]
        if not isinstance(utilization_ranges, list):
            get_logger().info(
                f"{utilization_ranges} should be a list of tuples containing the ranges"
            )
            return pd.DataFrame()
        if len(utilization_ranges) == 0:
            get_logger().info(f"{utilization_ranges} cannot be empty")
            return pd.DataFrame()
        if any(
                len(utilization_range) != 2
                for utilization_range in utilization_ranges):
            get_logger().info(
                f"Each interval in {utilization_ranges} must have a start and end value"
            )
            return pd.DataFrame()

        def helper(x, util_ranges):
            for start, end in util_ranges:
                if start <= float(x) <= end:
                    return (start, end)
            return ()

        self.sys_metrics_df["ranges"] = self.sys_metrics_df.apply(
            lambda x: helper(x["value"], utilization_ranges), axis=1)
        device_sys_df = self.sys_metrics_df[self.sys_metrics_df["ranges"] != ()]

        if device_sys_df.empty:
            return device_sys_df

        usage_stats = device_sys_df[device_sys_df["type"].str.contains(
            "|".join(resources)).any(level=0)]

        df_grouped = (usage_stats.groupby(
            ["type", "nodeID", "ranges"])["ranges"].describe().reset_index())
        df_grouped = df_grouped.drop(["unique", "top", "freq"], axis="columns")
        df_grouped = (df_grouped.set_index(
            ["type", "nodeID"]).pivot(columns="ranges")["count"].reset_index())
        df_grouped = df_grouped.fillna(0)
        return df_grouped
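A hedged usage sketch; `metrics_reader` stands in for an instance of the class that defines get_device_usage_stats, and Resource is the enum referenced above.

# Count samples per node falling into custom GPU-utilization buckets (illustrative ranges).
custom_ranges = [(0, 25), (25, 75), (75, 100)]
gpu_usage_df = metrics_reader.get_device_usage_stats(
    device=Resource.GPU, utilization_ranges=custom_ranges
)

# Invalid arguments are only logged; an empty DataFrame is returned instead of raising.
assert metrics_reader.get_device_usage_stats(utilization_ranges=[]).empty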
    def load_python_profile_stats(self):
        """Load the stats in by creating the profile directory, downloading each stats directory from s3 to the
        profile directory, parsing the metadata from each stats directory name and creating a StepPythonProfileStats
        entry corresponding to the stats file in the stats directory.

        For cProfile, the stats file name is `python_stats`.
        For pyinstrument, the stats file name is `python_stats.json`.
        """
        python_profile_stats = []

        self._set_up_profile_dir()

        list_request = ListRequest(Bucket=self.bucket_name, Prefix=self.prefix)
        s3_filepaths = S3Handler.list_prefix(list_request)
        object_requests = [
            ReadObjectRequest(
                os.path.join("s3://", self.bucket_name, s3_filepath))
            for s3_filepath in s3_filepaths
        ]
        objects = S3Handler.get_objects(object_requests)

        for full_s3_filepath, object_data in zip(s3_filepaths, objects):
            if os.path.basename(full_s3_filepath) not in (
                    CPROFILE_STATS_FILENAME,
                    PYINSTRUMENT_JSON_FILENAME,
                    PYINSTRUMENT_HTML_FILENAME,
            ):
                get_logger().info(
                    f"Unknown file {full_s3_filepath} found, skipping...")
                continue

            path_components = full_s3_filepath.split("/")
            framework, profiler_name, node_id, stats_dir, stats_file = path_components[-5:]

            stats_dir_path = os.path.join(self.profile_dir, node_id, stats_dir)
            os.makedirs(stats_dir_path, exist_ok=True)
            stats_file_path = os.path.join(stats_dir_path, stats_file)

            with open(stats_file_path, "wb") as f:
                f.write(object_data)

            python_profile_stats.append(
                StepPythonProfileStats(framework, profiler_name, node_id,
                                       stats_dir, stats_file_path))
        python_profile_stats.sort(
            key=lambda x: (x.start_time_since_epoch_in_micros, x.node_id)
        )  # sort stats by start time since epoch, then node ID.
        return python_profile_stats
    def __init__(self, rule_name, message_type, message_endpoint):
        self._topic_name = "SMDebugRules"
        self._logger = get_logger()

        if message_type == "sms" or message_type == "email":
            self._protocol = message_type
        else:
            self._protocol = None
            self._logger.info(
                f"Unsupported message type:{message_type} in MessageAction. Returning"
            )
            return
        self._message_endpoint = message_endpoint

        # The two attributes below are kept to help in tests
        self._last_send_mesg_response = None
        self._last_subscription_response = None

        env_region_name = os.getenv("AWS_REGION", "us-east-1")

        self._sns_client = boto3.client("sns", region_name=env_region_name)

        self._topic_arn = self._create_sns_topic_if_not_exists()

        self._subscribe_mesgtype_endpoint()
        self._logger.info(
            f"Registering messageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}"
        )
        self._rule_name = rule_name
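A hedged construction sketch for the action above; the rule name and endpoint are illustrative, and the call requires AWS credentials with SNS permissions.

action = MessageAction(
    rule_name="LossNotDecreasing",
    message_type="email",
    message_endpoint="alerts@example.com",
)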
Example #8
def set_up_logging_and_error_handling_agent(out_dir, stack_trace_filepath):
    """
    Set up each test to:
        - Add a logging handler to write all logs to a file (which will be used to verify caught errors in the tests)
        - Remove the duplicate logging filter
        - Reset the error handling agent after the test so that smdebug is reenabled.
    """
    old_create_from_json = Hook.create_from_json_file
    del_hook()

    logger = get_logger()
    os.makedirs(out_dir)
    file_handler = logging.FileHandler(filename=stack_trace_filepath)
    logger.addHandler(file_handler)
    duplicate_log_filter = None
    for log_filter in logger.filters:
        if isinstance(log_filter, DuplicateLogFilter):
            duplicate_log_filter = log_filter
            break
    logger.removeFilter(duplicate_log_filter)

    yield

    Hook.create_from_json_file = old_create_from_json
    error_handling_agent.disable_smdebug = False
    error_handling_agent.hook = None

    logger.removeHandler(file_handler)
    logger.addFilter(duplicate_log_filter)
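A hedged sketch of wiring the generator above into a pytest fixture; the fixture name and paths are assumptions.

import os
import pytest

@pytest.fixture
def error_handling_env(tmp_path):
    out_dir = str(tmp_path / "logs")
    stack_trace_filepath = os.path.join(out_dir, "stack_trace.log")
    # Delegate setup (before the yield) and teardown (after the yield) to the generator above.
    yield from set_up_logging_and_error_handling_agent(out_dir, stack_trace_filepath)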
Example #9
    def __init__(self,
                 bucket_name,
                 key_name,
                 aws_access_key_id=None,
                 aws_secret_access_key=None,
                 binary=True):
        super().__init__()
        self.bucket_name = bucket_name
        # S3 is not like a Unix file system where multiple slashes are normalized to one
        self.key_name = re.sub("/+", "/", key_name)
        self.binary = binary
        self._init_data()
        self.flushed = False
        self.logger = get_logger()

        self.current_len = 0
        self.s3 = boto3.resource("s3", region_name=get_region())
        self.s3_client = boto3.client("s3", region_name=get_region())

        # Set the desired multipart threshold value (5GB)
        MB = 1024**2
        self.transfer_config = TransferConfig(multipart_threshold=5 * MB)

        # check if the bucket exists
        buckets = [
            bucket["Name"]
            for bucket in self.s3_client.list_buckets()["Buckets"]
        ]
        if self.bucket_name not in buckets:
            self.s3_client.create_bucket(ACL="private",
                                         Bucket=self.bucket_name)
Example #10
    def __init__(
        self, bucket_name, key_name, aws_access_key_id=None, aws_secret_access_key=None, binary=True
    ):
        super().__init__()
        self.bucket_name = bucket_name
        # S3 is not like a Unix file system where multiple slashes are normalized to one
        self.key_name = re.sub("/+", "/", key_name)
        self.binary = binary
        self._init_data()
        self.flushed = False
        self.logger = get_logger()

        self.current_len = 0
        self.s3 = boto3.resource("s3", region_name=get_region())
        self.s3_client = boto3.client("s3", region_name=get_region())

        # Set the desired multipart threshold value (5GB)
        MB = 1024 ** 2
        self.transfer_config = TransferConfig(multipart_threshold=5 * MB)

        # Create the bucket if it does not exist
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
        except botocore.exceptions.ClientError:
            self.s3_client.create_bucket(ACL="private", Bucket=self.bucket_name)
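A hedged refinement of the head_bucket pattern above: distinguish a missing bucket (404) from a permission error (403) instead of creating the bucket on any ClientError. This is a sketch, not part of the original snippet.

from botocore.exceptions import ClientError

def ensure_bucket_exists(s3_client, bucket_name):
    """Create the bucket only when head_bucket reports it as missing (404)."""
    try:
        s3_client.head_bucket(Bucket=bucket_name)
    except ClientError as e:
        if e.response["Error"]["Code"] == "404":
            s3_client.create_bucket(ACL="private", Bucket=bucket_name)
        else:
            raise  # e.g. 403: the bucket exists but the caller lacks permission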
 def __init__(self):
     self._processes = dict()
     self._trace_events = list()
     self._start_timestamp = 0
     self._start_time_known = False
     # The timestamps in trace events are in microseconds; we multiply by 1000 to convert to nanoseconds
     self._timescale_multiplier_for_ns = 1000
     self.logger = get_logger("smdebug-profiler")
Example #12
 def __init__(self, actions_str, rule_name):
     self._actions = []
     self._logger = get_logger()
     actions_str = actions_str.strip() if actions_str is not None else ""
     if actions_str == "":
         self._logger.info(f"No action specified for rule {rule_name}.")
         return
     self._register_actions(actions_str, rule_name)
Example #13
    def _dump_info_to_json(self, training_info, trace_json_file):
        """
        This function dumps the training info gathered into the
        json file passed.
        """
        with open(trace_json_file, "r+") as f:
            data = json.load(f)

        for phase, metrics in training_info.items():
            if not metrics:
                get_logger().error(f"No metrics captured after profiling for {phase}!")
                continue

            # Getting the min start_time to get the start_time
            start = min(x[1] for x in metrics)
            # Calculating the max end time using duration.
            end = max(x[1] + x[2] for x in metrics)
            phase = "BackwardPass" if phase != "ForwardPass" else phase
            main_entry = {
                "pid": "/" + phase,
                "tid": phase,
                "ph": "X",
                "ts": start / 1000,
                "dur": (end - start) / 1000,
                "name": phase,
                "args": {"group_id": phase, "long_name": phase},
            }
            data["traceEvents"].append(main_entry)

            for metric in metrics:
                entry = {
                    "pid": "/" + phase,
                    "tid": phase + "ops",
                    "ph": "X",
                    "args": {"group_id": phase, "long_name": metric[0]},
                    "ts": metric[1] / 1000,
                    "dur": metric[2] / 1000,
                    "name": metric[0],
                }
                data["traceEvents"].append(entry)

        get_logger().info(f"Dumping into file {trace_json_file}")
        with open(trace_json_file, "w+") as outfile:
            json.dump(data, outfile)
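The entries written above follow the Chrome Trace Event format: "X" (complete) events whose ts and dur are expressed in microseconds, which is presumably why the raw values are divided by 1000. An illustrative resulting file looks roughly like this (values are made up).

example_trace = {
    "traceEvents": [
        {"pid": "/ForwardPass", "tid": "ForwardPass", "ph": "X",
         "ts": 1000.0, "dur": 250.0, "name": "ForwardPass",
         "args": {"group_id": "ForwardPass", "long_name": "ForwardPass"}},
    ]
}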
Example #14
 def __init__(self):
     """Initialize the parser to be disabled for profiling and detailed profiling.
     """
     self.last_json_config = None
     self.config = None
     self.profiling_enabled = False
     self.logger = get_logger()
     self.last_logging_statuses = defaultdict(lambda: False)
     self.current_logging_statuses = defaultdict(lambda: False)
     self.load_config()
 def __init__(self, rule_name, training_job_prefix):
     self._training_job_prefix = training_job_prefix
     env_region_name = os.getenv("AWS_REGION", "us-east-1")
     self._logger = get_logger()
     self._logger.info(
         f"StopTrainingAction created with training_job_prefix:{training_job_prefix} and region:{env_region_name}"
     )
     self._sm_client = boto3.client("sagemaker",
                                    region_name=env_region_name)
     self._rule_name = rule_name
     self._found_jobs = self._get_sm_tj_jobs_with_prefix()
Example #16
 def __init__(self,
              path,
              index_writer=None,
              verbose=False,
              write_checksum=False):
     self._filename = path
     self.tfrecord_writer = None
     self.verbose = verbose
     self._num_outstanding_events = 0
     self._logger = get_logger()
     self.write_checksum = write_checksum
     self.index_writer = index_writer
Example #17
 def __init__(self, path, mode):
     super().__init__()
     self.path = path
     self.mode = mode
     self.logger = get_logger()
     ensure_dir(path)
     if mode in WRITE_MODES:
         self.temp_path = get_temp_path(self.path)
         ensure_dir(self.temp_path)
         self.open(self.temp_path, mode)
     else:
         self.open(self.path, mode)
Example #18
    def __init__(self, base_trial, other_trials=None):
        self.base_trial = base_trial
        self.other_trials = other_trials

        self.trials = [base_trial]
        if self.other_trials is not None:
            self.trials += [x for x in self.other_trials]

        self.req_tensors = RequiredTensors(self.base_trial, self.other_trials)

        self.logger = get_logger()
        self.rule_name = self.__class__.__name__
    def __init__(self):
        self._callable_fn_cache = {}  # Maps fetches to callable_fn

        cache_type = os.getenv(CALLABLE_CACHE_ENV_VAR, DEFAULT_CALLABLE_CACHE)

        if cache_type == CacheType.CACHE_PER_MODE.name:
            self.cache_type = CacheType.CACHE_PER_MODE
        elif cache_type == CacheType.CLEAR_FOR_EACH_MODE.name:
            self.cache_type = CacheType.CLEAR_FOR_EACH_MODE
        else:
            self.cache_type = CacheType.OFF

        logger.get_logger().debug(
            f"Created callable_fn cache of type {self.cache_type.name}")

        if self.cache_type == CacheType.CACHE_PER_MODE:
            # create callable cache per mode
            for mode in ALLOWED_MODES:
                self._callable_fn_cache[mode] = {}
        else:
            # cleared cache at the end of each mode
            self._callable_fn_cache = {}
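A hedged usage sketch reusing the constants referenced above; the class name CallableCache is an assumption for whatever class owns this __init__.

import os

# Select per-mode caching before the cache object is constructed.
os.environ[CALLABLE_CACHE_ENV_VAR] = CacheType.CACHE_PER_MODE.name
cache = CallableCache()  # assumed class name
assert cache.cache_type == CacheType.CACHE_PER_MODE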
    def load_python_profile_stats(self):
        """Load the stats in by scanning each stats directory in the profile directory, parsing the metadata from the
        stats directory name and creating a StepPythonProfileStats entry corresponding to the stats file in the
        stats directory.

        For cProfile, the stats file name is `python_stats`.
        For pyinstrument, the stats file name is `python_stats.json` or `python_stats.html`.
        """
        python_profile_stats = []
        framework = os.path.basename(os.path.dirname(self.profile_dir))
        for node_id in os.listdir(self.profile_dir):
            node_dir_path = os.path.join(self.profile_dir, node_id)
            for stats_dir in os.listdir(node_dir_path):
                stats_dir_path = os.path.join(node_dir_path, stats_dir)
                for filename in os.listdir(stats_dir_path):
                    if filename == CPROFILE_STATS_FILENAME:
                        profiler_name = CPROFILE_NAME
                        stats_file_path = os.path.join(
                            stats_dir_path, CPROFILE_STATS_FILENAME)
                    elif filename == PYINSTRUMENT_JSON_FILENAME:
                        profiler_name = PYINSTRUMENT_NAME
                        stats_file_path = os.path.join(
                            stats_dir_path, PYINSTRUMENT_JSON_FILENAME)
                    elif filename == PYINSTRUMENT_HTML_FILENAME:
                        profiler_name = PYINSTRUMENT_NAME
                        stats_file_path = os.path.join(
                            stats_dir_path, PYINSTRUMENT_HTML_FILENAME)
                    else:
                        get_logger().info(
                            f"Unknown file {filename} found, skipping...")
                        continue
                    python_profile_stats.append(
                        StepPythonProfileStats(framework, profiler_name,
                                               node_id, stats_dir,
                                               stats_file_path))
        python_profile_stats.sort(
            key=lambda x: (x.start_time_since_epoch_in_micros, x.node_id)
        )  # sort stats by start time since epoch, then node ID.
        return python_profile_stats
Example #21
 def __init__(self, type=""):
     # list of ProcessInfo found in this file
     self._processes = dict()
     # reverse mapping from name to id
     self._process_name_to_id = dict()
     self._trace_events = []
     """
     The _pid_stacks maintain the directory of stacks indexed using pid. The stack contains 'B' type events.
     The stack will be popped as we process the 'E' events for the same pid.
     """
     self._pid_stacks = dict()
     self._start_timestamp = 0
     self.type = type
     self.logger = get_logger()
 def __init__(self, framework=None):
     """Initialize the parser to be disabled for profiling and detailed profiling.
     """
     self.framework = framework
     self.last_json_config = None
     self.config = None
     self.profiling_enabled = False
     self.logger = get_logger()
     self.last_logging_statuses = defaultdict(lambda: False)
     self.current_logging_statuses = defaultdict(lambda: False)
     self.load_config()
     self.python_profiler = (
         PythonProfiler.get_python_profiler(self.config, self.framework)
         if self.is_python_profiling_enabled()
         else None
     )
 def __init__(self, num_retries=5, debug=False):
     # If you are creating an S3Handler object in Jupyter, make sure nest_asyncio has been applied
     check_notebook()
     self.loop = asyncio.get_event_loop()
     self.client = aioboto3.client("s3",
                                   loop=self.loop,
                                   region_name=get_region())
     self.num_retries = num_retries
     self.logger = get_logger()
     if debug:
         self.loop.set_debug(True)
         self.loop.slow_callback_duration = 4
         logging.basicConfig(level=logging.DEBUG)
         aioboto3.set_stream_logger(name="boto3",
                                    level=logging.DEBUG,
                                    format_string=None)
Example #24
def test_tf_device_name_serialize_and_deserialize():
    logger = get_logger()
    import tensorflow.compat.v1 as tf

    device_name = tf.test.gpu_device_name()
    if not bool(device_name):
        logger.warning(
            "There is no GPU Support on this machine. Please ignore the cuInit errors generated above"
        )
        device_name = "/device:GPU:0"

    serialized_device_name = serialize_tf_device(device_name)
    assert deserialize_tf_device(serialized_device_name) == device_name

    device_name = "/replica:0/task:0/device:GPU:0"
    serialized_device_name = serialize_tf_device(device_name)
    assert deserialize_tf_device(serialized_device_name) == device_name
    def __init__(self, use_in_memory_cache=False):
        self.logger = get_logger()

        self._event_parsers = []

        # This is a mapping from timestamp -> [event_file]
        self._timestamp_to_filename = dict()

        # This is a set of parsed event files. An entry is added to this set only once the complete file has been read.
        self._parsed_files = set()

        # The startAfter_prefix is used in the ListPrefix call to poll for available trace files in the S3 bucket.
        # The prefix lags behind the last polled trace file by a tunable trailing duration. This ensures that we do
        # not miss a trace file whose timestamp is earlier than the last polled timestamp but which arrived after
        # we had polled.

        self._startAfter_prefix = ""
        self.prefix = ""
        self._cache_events_in_memory = use_in_memory_cache
Example #26
    def __init__(self, base_trial, action_str, other_trials=None):
        self.base_trial = base_trial
        self.other_trials = other_trials

        self.trials = [base_trial]
        if self.other_trials is not None:
            self.trials += [x for x in self.other_trials]

        self.req_tensors = RequiredTensors(self.base_trial, self.other_trials)

        self.logger = get_logger()
        self.rule_name = self.__class__.__name__
        self._actions = Actions(action_str, rule_name=self.rule_name)
        self.report = {
            "RuleTriggered": 0,
            "Violations": 0,
            "Details": {},
            "Datapoints": 0,
            "RuleParameters": "",
        }
Example #27
    def __init__(self, name, trial_dir, output_dir):
        self.name = name
        # Trial dir is the S3/local directory that contains profiling data captured during runtime.
        self.path = trial_dir

        self.logger = get_logger()
        self.first_timestamp = 0
        self.get_first_timestamp()

        # The output directory will contain data emitted by rules, which is further published to S3.
        self.output_dir = output_dir
        if output_dir and not os.path.exists(output_dir):
            pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

        # .sagemaker-ignore will not be picked up by service code for uploading. It is used to save temp files.
        self.temp_dir = os.path.join(output_dir, ".sagemaker-ignore")
        if not os.path.exists(self.temp_dir):
            pathlib.Path(self.temp_dir).mkdir(parents=True, exist_ok=True)
        self.logger.info(
            "Output files of ProfilerTrial will be saved to {}".format(
                self.output_dir))
import os
import threading
import time
from datetime import datetime

# Third Party
import six

# First Party
from smdebug.core.access_layer.file import SMDEBUG_TEMP_PATH_SUFFIX
from smdebug.core.locations import TraceFileLocation
from smdebug.core.logger import get_logger
from smdebug.core.utils import ensure_dir, get_node_id
from smdebug.profiler.profiler_constants import CONVERT_TO_MICROSECS, PYTHONTIMELINE_SUFFIX

logger = get_logger()


def _get_sentinel_event(base_start_time):
    """Generate a sentinel trace event for terminating worker."""
    return TimelineRecord(timestamp=time.time(),
                          base_start_time=base_start_time)


"""
TimelineRecord represents one trace event that will be written into a trace event JSON file.
"""


class TimelineRecord:
    def __init__(
Example #29
 def __init__(self, path):
     self.event_file_retry_limit = int(
         os.getenv(MISSING_EVENT_FILE_RETRY_LIMIT_KEY, MISSING_EVENT_FILE_RETRY_LIMIT)
     )
     self.path = path
     self.logger = get_logger()
Example #30
    def __init__(self,
                 name,
                 range_steps=None,
                 parallel=True,
                 check=False,
                 index_mode=True,
                 cache=False):
        self.name = name
        self._tensors = {}

        # nested dictionary from mode -> mode_step -> global_step
        # will not have global mode as a key
        self._mode_to_global = {}

        # dictionary from global_step -> (mode, mode_step)
        # can have global mode as a value
        self._global_to_mode = {}

        self.logger = get_logger()
        self.parallel = parallel
        self.check = check
        self.range_steps = range_steps
        self.collection_manager = None
        self.loaded_all_steps = False
        self.cache = cache
        self.path = None
        self.index_reader = None
        self.index_tensors_dict = {}
        self.index_mode = index_mode
        self.last_event_token = None
        self.last_index_token = None
        self.worker_set = set()
        self.global_step_to_tensors_map = dict()
        self.mode_to_tensors_map = dict()
        self.num_workers = 0
        self.workers_for_global_step = {}
        self.last_complete_step = -1
        """
        INCOMPLETE_STEP_WAIT_WINDOW defines the maximum number
        of incomplete steps that the trial will wait for before marking
        half of them as complete.
        """
        self._incomplete_wait_for_step_window = int(
            os.getenv(INCOMPLETE_STEP_WAIT_WINDOW_KEY,
                      INCOMPLETE_STEP_WAIT_WINDOW_DEFAULT))

        # this is turned off during rule invocation for performance reasons since
        # required tensors are already fetched
        self.dynamic_refresh = True
        # number of seconds to wait before refreshing after seeing end of trial
        self._training_end_delay_refresh = int(
            os.getenv(TRAINING_END_DELAY_REFRESH_KEY,
                      TRAINING_END_DELAY_REFRESH_DEFAULT))

        if self.range_steps is not None:
            assert self.range_steps[0] is None or (
                isinstance(self.range_steps[0], int) and self.range_steps[0] >= 0)
            assert self.range_steps[1] is None or (
                isinstance(self.range_steps[1], int) and self.range_steps[1] >= 0)
            if self.range_steps[0] is not None and self.range_steps[1] is not None:
                assert int(self.range_steps[1]) > int(self.range_steps[0]), (
                    "range_steps should be of the form "
                    "(begin, end) where begin is less than end")
                self.logger.info(
                    "Trial {} will look for steps between {} and {}".format(
                        self.name, self.range_steps[0], self.range_steps[1]))
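A hedged illustration of the range_steps contract enforced above; SomeTrial is a stand-in for a concrete subclass of the class that owns this __init__.

# (begin, end) with begin < end; either bound may be None for an open interval.
trial = SomeTrial(name="resnet-run", range_steps=(100, 200))        # look only at steps in this range
open_trial = SomeTrial(name="resnet-run", range_steps=(None, 500))  # everything before step 500
# SomeTrial(name="bad", range_steps=(200, 100))  # would trip the assertion above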