Example #1
 def fhook(self, module, inputs, outputs):
     # Stop python profiling at the end of the forward pass, and restart it from
     # this phase if stats should be saved for the current step.
     if python_profiler:
         python_profiler.stop_profiling(
             StepPhase.FORWARD_PASS_END,
             end_mode=mode_keys_to_python_profile_mode(self.mode),
             end_step=self.step,
         )
         if self.profiler_config_parser.should_save_metrics(
                 MetricsCategory.PYTHON_PROFILING, self.step):
             python_profiler.start_profiling(
                 StepPhase.FORWARD_PASS_END,
                 start_mode=mode_keys_to_python_profile_mode(self.mode),
                 start_step=self.step,
             )
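For context, fhook above is a PyTorch forward hook: it only runs after it has been registered on a module. A minimal, self-contained sketch of that wiring, with a print standing in for the profiler calls (the names below are illustrative, not the smdebug registration code):

    import torch
    import torch.nn as nn

    def demo_forward_hook(module, inputs, output):
        # Stands in for fhook above: PyTorch invokes a forward hook with
        # (module, inputs, output) after every forward pass of the module.
        print(f"forward finished for {module.__class__.__name__}")

    model = nn.Linear(4, 2)
    handle = model.register_forward_hook(demo_forward_hook)
    model(torch.randn(1, 4))  # triggers demo_forward_hook
    handle.remove()           # detach the hook when no longer needed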
Example #2
 def close(self):
     self._cleanup()
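     # Note: profiling is restarted below, apparently so that user code running
     # after the hook closes (between the last step and the end of the script)
     # still lands in a final profiling window.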
     if python_profiler:
         python_profiler.start_profiling(
             StepPhase.STEP_END,
             start_mode=mode_keys_to_python_profile_mode(self.mode),
             start_step=self.mode_steps[self.mode],
         )
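The restart at STEP_END above can look odd: close() tears the hook down, then immediately opens a new profiling window. The apparent intent is to capture whatever user code runs after the hook is closed, until the stats are finally dumped. Below is a minimal sketch of this bracketing pattern built on the standard-library cProfile; the TinyPythonProfiler class and phase strings are illustrative stand-ins, not the smdebug PythonProfiler API.

    import cProfile
    import io
    import pstats

    class TinyPythonProfiler:
        """Stand-in for python_profiler: stop closes a window, start opens the next."""

        def __init__(self):
            self._prof = None

        def start_profiling(self, phase):
            self._prof = cProfile.Profile()
            self._prof.enable()

        def stop_profiling(self, phase):
            if self._prof is None:
                return
            self._prof.disable()
            stream = io.StringIO()
            pstats.Stats(self._prof, stream=stream).sort_stats("cumulative").print_stats(3)
            print(f"window closed at {phase}:\n{stream.getvalue()}")
            self._prof = None

    profiler = TinyPythonProfiler()
    profiler.start_profiling("step_start")
    sum(i * i for i in range(100_000))    # the "training step" body
    profiler.stop_profiling("step_end")
    profiler.start_profiling("step_end")  # reopen, as close() does above
    post_step_work = sum(range(10_000))   # post-hook user code lands in this window
    profiler.stop_profiling("script_end")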
Example #3
    def _handle_step_python_profiling(self, step_phase: StepPhase, mode: ModeKeys, current_step):
        """Handle python profiling at the given step phase, mode, and step.

        If python profiling is enabled, stop the current profiling window and, if
        python profiling stats should be saved for the current step, start a new one.
        """
        if not self.is_python_profiling_enabled():
            return

        self.python_profiler.stop_profiling(
            step_phase, end_mode=mode_keys_to_python_profile_mode(mode), end_step=current_step
        )

        if self.should_save_metrics(MetricsCategory.PYTHON_PROFILING, current_step):
            self.python_profiler.start_profiling(
                step_phase,
                start_mode=mode_keys_to_python_profile_mode(mode),
                start_step=current_step,
            )
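Examples #1 and #4 inline this same stop-then-maybe-restart sequence; the helper above factors it out so each hook only names the phase boundary it sits at. A self-contained sketch of the call sites (DemoHook, its attributes, and this StepPhase enum are stand-ins for the real hook object and smdebug's own StepPhase):

    from enum import Enum

    class StepPhase(Enum):
        STEP_START = "step_start"
        FORWARD_PASS_END = "forward_pass_end"

    class DemoHook:
        def __init__(self):
            self.mode = "TRAIN"
            self.step = 0

        def _handle_step_python_profiling(self, step_phase, mode, current_step):
            # The real helper stops the profiler here and conditionally restarts it.
            print(f"profiling boundary {step_phase.value}: {mode} step {current_step}")

        def forward_pre_hook(self, module, inputs):
            self._handle_step_python_profiling(StepPhase.STEP_START, self.mode, self.step)

        def forward_hook(self, module, inputs, outputs):
            self._handle_step_python_profiling(StepPhase.FORWARD_PASS_END, self.mode, self.step)

    hook = DemoHook()
    hook.forward_pre_hook(None, None)    # step boundary
    hook.forward_hook(None, None, None)  # forward-pass boundary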
Example #4
    def forward_pre_hook(self, module, inputs):
        # Close the writers from the past step if still open, which writes out
        # the gradients recorded for that step.
        if self.writer is not None:
            self._close_writers()
        self._close_tb_writer()

        if not self.prepared_collections:
            # At this point all collections need to be ready; that may not have been
            # true when the hook was created, since user code that runs after hook
            # creation can still add collections.
            self._prepare_collections()
            self.prepared_collections = True

        self._increment_step()

        ## preparing for step metrics
        # The last operation can be a forward pass (an eval loop is running, or the
        # module's forward is called multiple times per step, e.g. an RNN), or a
        # backward pass (the train backward loop just finished and we are at forward again).

        # we will log all outstanding forward and backward events
        self.log_outstanding_timeline_metrics()

        self.step_event = self._TraceEventData(
            phase="Step:" + str(self.mode),
            op_name="Step:" + str(self.mode),
            start_time=time.time(),
            dur=0,  # end time of step_event is updated on each subsequent forward or backward event
            pid=os.getpid(),
            step_num=str(self.mode_steps[self.mode]),
        )
        self.parent_forward_event = self._TraceEventData(
            phase="Forward",
            op_name=module._module_name,
            start_time=time.time(),
            dur=0,  # end time of parent_forward_event is updated on each subsequent forward event
            pid=os.getpid(),
            step_num=str(self.mode_steps[self.mode]),
        )

        self.profiler_config_parser.load_config()

        # Stop python profiling if the python profiler is currently profiling, and
        # restart it if metrics should be saved for this step.
        if python_profiler:
            python_profiler.stop_profiling(
                StepPhase.STEP_START,
                end_mode=mode_keys_to_python_profile_mode(self.mode),
                end_step=self.step,
            )
            if self.profiler_config_parser.should_save_metrics(
                MetricsCategory.PYTHON_PROFILING, self.step
            ):
                python_profiler.start_profiling(
                    StepPhase.STEP_START,
                    start_mode=mode_keys_to_python_profile_mode(self.mode),
                    start_step=self.step,
                )

        if self.autograd_profiler_enabled:
            self._collect_torch_profiling_data_if_profiler_enabled()

        # should we re-enable profiling for this step?
        if (
            self.profiler_config_parser.should_save_metrics(
                MetricsCategory.DETAILED_PROFILING, self.step
            )
            and not self.autograd_profiler_enabled
        ):
            self.autograd_profiler_enabled = True
            if is_pt_1_5():
                torch.autograd._enable_profiler(
                    torch.autograd.ProfilerConfig(self.profiler, False)
                )
            elif is_pt_1_6():
                torch.autograd._enable_profiler(
                    torch.autograd.ProfilerConfig(self.profiler, False, False)
                )
            elif is_pt_1_7():
                torch.autograd._enable_profiler(
                    torch.autograd.ProfilerConfig(self.profiler, False, False, False)
                )
            elif is_pt_1_8():
                # PT 1.8 uses the legacy entry point, which takes one more
                # ProfilerConfig flag.
                torch.autograd._enable_profiler_legacy(
                    torch.autograd.ProfilerConfig(self.profiler, False, False, False, False)
                )
            else:
                self.logger.warning(
                    f"Detailed profiling using the autograd profiler is not supported "
                    f"for torch version {torch.__version__}"
                )
                self.autograd_profiler_enabled = False

            if self.autograd_profiler_enabled:
                self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS

        if self.is_smdataparallel_profiling:
            # Stop smdataparallel profiling carried over from the previous step
            stop_smdataparallel_profiler(
                smdataparallel, self.profiler_config_parser.config.local_path
            )
        self.is_smdataparallel_profiling = False
        if self.profiler_config_parser.should_save_metrics(
            MetricsCategory.SMDATAPARALLEL_PROFILING, self.step
        ):
            start_smdataparallel_profiler(
                smdataparallel, self.profiler_config_parser.config.local_path
            )
            self.is_smdataparallel_profiling = True

        if self._get_collections_to_save_for_step():
            self._initialize_writers()
            self._log_params(module)

        if self.last_saved_step is not None and not self.exported_collections:
            self.export_collections()
            self.exported_collections = True

        self.first_forward_submodule_name = None
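As in Example #1, forward_pre_hook only fires once it is registered. A pre-hook runs before the module's forward, which is why it is a natural place to close out the previous step and open the next one. A minimal registration sketch (illustrative names, not the smdebug registration code):

    import torch
    import torch.nn as nn

    step = {"n": 0}

    def demo_forward_pre_hook(module, inputs):
        # PyTorch invokes a pre-hook with (module, inputs) *before* module.forward,
        # i.e. at the start of each new step, which is where Example #4 rotates
        # writers and profiling windows.
        step["n"] += 1
        print(f"step {step['n']} starting for {module.__class__.__name__}")

    model = nn.Linear(4, 2)
    handle = model.register_forward_pre_hook(demo_forward_pre_hook)
    for _ in range(2):
        model(torch.randn(1, 4))  # each call fires the pre-hook first
    handle.remove()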