示例#1
0
文件: __init__.py 项目: bdtmnk/labml
    def start(self,
              *,
              run_uuid: Optional[str] = None,
              checkpoint: Optional[int] = None):
        """Start the experiment, optionally resuming from an earlier run.

        Args:
            run_uuid: UUID of a previous run to load state from. When ``None``
                the run starts at global step 0.
            checkpoint: Checkpoint of ``run_uuid`` to load. ``None`` is
                replaced by ``-1`` (presumably "latest checkpoint" -- confirm
                against ``__start_from_checkpoint``). Unused when
                ``run_uuid`` is ``None``.

        Returns:
            An ``ExperimentWatcher`` for this experiment.

        Raises:
            SystemExit: with code 1 on rank 0 when ``check_repo_dirty`` is set
                and the run reports uncommitted changes.
        """
        import sys

        if run_uuid is not None:
            if checkpoint is None:
                checkpoint = -1
            global_step = self.__start_from_checkpoint(run_uuid, checkpoint)
        else:
            global_step = 0

        self.run.start_step = global_step

        # Writers (and self.web_api) are set up by _start_tracker.
        self._start_tracker()
        tracker().set_start_global_step(global_step)

        if self.distributed_rank == 0:
            self.__print_info()
            if self.check_repo_dirty and self.run.is_dirty:
                logger.log([
                    ("[FAIL]", Text.danger),
                    " Cannot trial an experiment with uncommitted changes."
                ])
                # Use sys.exit: the bare `exit` builtin is added by the `site`
                # module for interactive use and may be missing (e.g. under
                # `python -S` or in embedded interpreters).
                sys.exit(1)

        # Everything below is skipped for evaluation runs.
        if not self.is_evaluate:
            # Only rank 0 registers the project and saves run info, but
            # every rank saves its PID.
            if self.distributed_rank == 0:
                from labml.internal.computer.configs import computer_singleton
                computer_singleton().add_project(lab_singleton().path)

                self.run.save_info()
            self._save_pid()

            # Config savers, external integrations and indicator saving are
            # handled by rank 0 only.
            if self.distributed_rank == 0:
                if self.configs_processor is not None:
                    self.configs_processor.add_saver(
                        FileConfigsSaver(self.run.configs_path))

                if self.web_api is not None:
                    self.web_api.start(self.run)
                    if self.configs_processor is not None:
                        self.configs_processor.add_saver(
                            self.web_api.get_configs_saver())
                        self.web_api.set_dynamic_handler(
                            ExperimentDynamicUpdateHandler(
                                self.configs_processor))

                if self.wandb is not None:
                    self.wandb.init(self.run.name, self.run.run_path)
                    if self.configs_processor is not None:
                        self.configs_processor.add_saver(
                            self.wandb.get_configs_saver())

                tracker().save_indicators(self.run.indicators_path)

        self.is_started = True
        return ExperimentWatcher(self)
示例#2
0
    def start(self,
              *,
              run_uuid: Optional[str] = None,
              checkpoint: Optional[int] = None):
        """Start the experiment, optionally resuming from an earlier run.

        Args:
            run_uuid: UUID of a previous run to load state from. When ``None``
                the run starts at global step 0.
            checkpoint: Checkpoint of ``run_uuid`` to load. ``None`` is
                replaced by ``-1`` (presumably "latest checkpoint" -- confirm
                against ``__start_from_checkpoint``). Unused when
                ``run_uuid`` is ``None``.

        Returns:
            An ``ExperimentWatcher`` for this experiment.

        Raises:
            SystemExit: with code 1 on rank 0 when ``check_repo_dirty`` is set
                and the run reports uncommitted changes.
        """
        import sys

        if run_uuid is not None:
            if checkpoint is None:
                checkpoint = -1
            global_step = self.__start_from_checkpoint(run_uuid, checkpoint)
        else:
            global_step = 0

        self.run.start_step = global_step

        self._start_tracker()
        tracker().set_start_global_step(global_step)

        if self.distributed_rank == 0:
            self.__print_info()
            if self.check_repo_dirty and self.run.is_dirty:
                logger.log([
                    ("[FAIL]", Text.danger),
                    " Cannot trial an experiment with uncommitted changes."
                ])
                # Use sys.exit: the bare `exit` builtin is added by the `site`
                # module for interactive use and may be missing (e.g. under
                # `python -S` or in embedded interpreters).
                sys.exit(1)

        # Everything below is skipped for evaluation runs.
        if not self.is_evaluate:
            # Only rank 0 saves run info, but every rank saves its PID.
            if self.distributed_rank == 0:
                self.run.save_info()
            self._save_pid()

            # Config savers, the web API and indicator saving are handled by
            # rank 0 only.
            if self.distributed_rank == 0:
                if self.configs_processor is not None:
                    self.configs_processor.add_saver(
                        FileConfigsSaver(self.run.configs_path))

                if self.web_api is not None:
                    self.web_api.set_info(run_uuid=self.run.uuid,
                                          name=self.run.name,
                                          comment=self.run.comment)
                    self.web_api.start()
                    if self.configs_processor is not None:
                        self.configs_processor.add_saver(
                            self.web_api.get_configs_saver())

                tracker().save_indicators(self.run.indicators_path)

                # PERF: hyper-parameters are deliberately not written to
                # TensorBoard here (via tracker().write_h_parameters with the
                # configs processor's hyper-parameters): it takes about
                # 4 seconds and would not reflect configs updated live.

        self.is_started = True
        return ExperimentWatcher(self)
示例#3
0
    def finish(self, status: str, details: object = None):
        """Record the final status of the run.

        Outside evaluation mode, appends one JSON line with ``status``,
        ``rank``, ``details`` and the end timestamp to the run log. It then
        finishes the tracker loop. If a web API client is set, it also
        reports the status to it.

        Args:
            status: Final status label for the run.
            details: Optional extra information. Must be JSON-serializable,
                since it is passed to ``json.dumps``.
        """
        # Take the timestamp unconditionally so `end_time` is bound for the
        # web API call even in evaluation mode, where the log write is skipped.
        end_time = time.time()

        if not self.is_evaluate:
            # Serialize first so a non-serializable `details` fails before the
            # log file is opened.
            data = json.dumps({'status': status,
                               'rank': self.distributed_rank,
                               'details': details,
                               'time': end_time}, indent=None)
            with open(str(self.run.run_log_path), 'a') as f:
                f.write(data + '\n')

        tracker().finish_loop()

        if self.web_api is not None:
            self.web_api.status(self.distributed_rank, status, details, end_time)
示例#4
0
    def save_checkpoint(self):
        """Save a checkpoint at the tracker's current global step.

        Does nothing in evaluation mode or on any process whose
        ``distributed_rank`` is not 0.
        """
        if self.is_evaluate or self.distributed_rank != 0:
            return

        self.checkpoint_saver.save(tracker().global_step)
示例#5
0
    def _start_tracker(self):
        """Reset the tracker's writers and attach the ones this run uses.

        In evaluation mode, or on a process whose ``distributed_rank`` is not
        0, only the reset happens and ``self.web_api`` is left as it was.

        On rank 0:

        - each name in ``self.writers`` (``'screen'``, ``'sqlite'``,
          ``'tensorboard'``, ``'file'``, ``'web_api'``) enables the matching
          writer;
        - ``self.web_api`` becomes an ``ApiExperiment`` when the ``'web_api'``
          writer is requested and the lab has a web API config, and ``None``
          otherwise.
        """
        tracker().reset_writers()

        if self.is_evaluate:
            return

        if self.distributed_rank != 0:
            return

        if 'screen' in self.writers:
            from labml.internal.tracker.writers import screen
            tracker().add_writer(screen.ScreenWriter())

        if 'sqlite' in self.writers:
            from labml.internal.tracker.writers import sqlite
            tracker().add_writer(sqlite.Writer(self.run.sqlite_path, self.run.artifacts_folder))

        if 'tensorboard' in self.writers:
            from labml.internal.tracker.writers import tensorboard
            tracker().add_writer(tensorboard.Writer(self.run.tensorboard_log_path))

        if 'file' in self.writers:
            from labml.internal.tracker.writers import file
            tracker().add_writer(file.Writer(self.run.log_file))

        # Default to "no web API" so the attribute is always defined on this
        # path. This includes the case where the writer is requested but the
        # lab has no web API config, which previously left a stale/unset value.
        self.web_api = None
        if 'web_api' in self.writers:
            web_api_conf = lab_singleton().web_api
            if web_api_conf is not None:
                from labml.internal.tracker.writers import web_api
                from labml.internal.api import ApiCaller
                from labml.internal.api.experiment import ApiExperiment
                # One caller is shared by the experiment client and the
                # tracker writer; requests are tagged with this run's UUID.
                api_caller = ApiCaller(web_api_conf.url,
                                       {'run_uuid': self.run.uuid},
                                       timeout_seconds=15)
                self.web_api = ApiExperiment(api_caller,
                                             frequency=web_api_conf.frequency,
                                             open_browser=web_api_conf.open_browser)
                tracker().add_writer(web_api.Writer(api_caller,
                                                    frequency=web_api_conf.frequency))