def __init__(self, config: dict, model: torch.nn.Module) -> None:
    self.logger = logging.getLogger(__name__)
    self.logger.info("started")
    self.config = config
    self.model = model

    # Resolve the loss criterion the way Inferno does:
    method = config[TRAINING][LOSS_CRITERION_CONFIG]["method"]
    if isinstance(method, str):
        # Look for the criterion in torch
        criterion_class = getattr(torch.nn, method, None)
        if criterion_class is None:
            # Look for it in the inferno extensions
            criterion_class = getattr(criteria, method, None)
        if criterion_class is None:
            raise ValueError(f"Criterion {method} not found.")
    else:
        criterion_class = method  # assume a criterion class was passed in directly

    self.criterion_class = criterion_class

    self.shutdown_event = threading.Event()

    self.training_shape = None
    self.valid_shapes = None
    self.shrinkage: Optional[Point] = None

    self.dry_run_queue = queue.Queue()
    self.dry_run_thread = threading.Thread(target=add_logger(self.logger)(self._dry_run_worker), name="DryRun")
    self.dry_run_thread.start()
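
# A minimal, standalone sketch of the criterion lookup performed above,
# assuming `criteria` is inferno's extensions module as imported in this
# file. A name like "MSELoss" resolves via torch.nn; a name like
# "SorensenDiceLoss" would only be found in the inferno fallback.
def _resolve_criterion_sketch(method: str) -> type:
    criterion_class = getattr(torch.nn, method, None)  # torch.nn first
    if criterion_class is None:
        criterion_class = getattr(criteria, method, None)  # then inferno extensions
    if criterion_class is None:
        raise ValueError(f"Criterion {method} not found.")
    return criterion_class
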
def add_devices(self, devices: Set[torch.device]) -> None:
    self.logger.debug("add devices %s", devices)
    for d in devices:
        assert d not in self.devices
        assert d not in self.shutdown_worker_events
        assert d not in self.forward_worker_threads
        self.shutdown_worker_events[d] = threading.Event()
        self.forward_worker_threads[d] = threading.Thread(
            target=add_logger(self.logger)(self._forward_worker), name=f"ForwardWorker({d})", kwargs={"device": d}
        )
        self.forward_worker_threads[d].start()

    self.devices.update(devices)
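
# Hypothetical usage sketch for add_devices: `server` stands in for an
# instance of the surrounding class, and the CUDA device is an assumption.
# Each added device gets its own forward worker thread plus a per-device
# shutdown event keyed by the device, so a single worker can be retired
# without stopping the whole server.
def _add_devices_usage_sketch(server) -> None:
    server.add_devices({torch.device("cpu"), torch.device("cuda:0")})
    # Presumably the forward worker watches its shutdown event; setting it
    # should let the thread wind down (assumption, not shown in this file):
    server.shutdown_worker_events[torch.device("cuda:0")].set()
    server.forward_worker_threads[torch.device("cuda:0")].join()
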
def __init__(self, config: dict, model: torch.nn.Module) -> None:
    self.logger = logging.getLogger(__name__)
    self.logger.info("started")
    self.config = config
    self.training_model = model
    self.shutdown_event = threading.Event()

    self.batch_size: Optional[int] = config.get(INFERENCE_BATCH_SIZE, None)

    self.forward_queue = queue.Queue()
    self.shutdown_worker_events = {}
    self.forward_worker_threads = {}

    self.devices = set()
    # self.add_devices({torch.device("cpu")})

    self.device_change_queue = queue.Queue()
    self.device_setter_thread = threading.Thread(
        target=add_logger(self.logger)(self._device_setter_worker), name="DeviceSetter"
    )
    self.device_setter_thread.start()
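
# A sketch of the queue-plus-event worker pattern behind the DeviceSetter
# thread above. The real _device_setter_worker body is not shown here, so
# this is only an assumed shape: block on the queue with a timeout so the
# shutdown event is polled regularly instead of blocking forever.
def _queue_worker_sketch(work_queue: queue.Queue, shutdown_event: threading.Event) -> None:
    while not shutdown_event.is_set():
        try:
            item = work_queue.get(timeout=1)  # wake up periodically to re-check shutdown
        except queue.Empty:
            continue
        print("processing", item)  # stand-in for the real device bookkeeping
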
def __init__(
    self,
    config: dict,
    model_file: bytes,
    model_state: bytes,
    optimizer_state: bytes,
    log_queue: Optional[mp.Queue] = None,
) -> None:
    """
    :param config: configuration dict
    :param model_file: bytes of file describing the neural network model
    :param model_state: binarized model state dict
    :param optimizer_state: binarized optimizer state dict
    """
    assert model_file
    for required in [MODEL_CLASS_NAME]:
        if required not in config:
            raise ValueError(f"{required} missing in config")

    self.config = config

    self.shutdown_event = threading.Event()

    self.logger = logging.getLogger(__name__)
    self.logger.info("started")
    self.valid_shapes: Optional[List[Point]] = None
    self.shrinkage: Optional[Point] = None
    self.idle_devices: List[torch.device] = []
    self.training_devices: List[torch.device] = []
    self.inference_devices: List[torch.device] = []

    # Write the user-supplied model file to a temp dir and import it as a module.
    self.tempdir = tempfile.mkdtemp()
    user_module_name = "usermodel"
    with open(os.path.join(self.tempdir, user_module_name + ".py"), "wb") as f:
        f.write(model_file)

    sys.path.insert(0, self.tempdir)
    user_module = importlib.import_module(user_module_name)

    self.model: torch.nn.Module = getattr(user_module, self.config[MODEL_CLASS_NAME])(
        **self.config.get(MODEL_INIT_KWARGS, {})
    )
    self.logger.debug("created user model")

    if model_state:
        self.logger.debug("load model state")
        try:
            self.model.load_state_dict(torch.load(io.BytesIO(model_state), map_location="cpu"))
        except Exception as e:
            self.logger.exception(e)
        else:
            self.logger.info("restored model state")

    try:
        self.logger.debug("start dryrun process")
        handler2dryrun_conn, dryrun2handler_conn = mp.Pipe()
        self._dry_run_proc = mp.Process(
            name="DryRun",
            target=run_dryrun,
            kwargs={"conn": dryrun2handler_conn, "config": config, "model": self.model, "log_queue": log_queue},
        )
        self._dry_run_proc.start()
        self._dry_run: IDryRun = create_client(IDryRun, handler2dryrun_conn)

        self.logger.debug("start training process")
        handler2training_conn, training2handler_conn = mp.Pipe()
        self._training_proc = mp.Process(
            target=run_training,
            name="Training",
            kwargs={
                "conn": training2handler_conn,
                "config": config,
                "model": self.model,
                "optimizer_state": optimizer_state,
                "log_queue": log_queue,
            },
        )
        self._training_proc.start()
        self._training: ITraining = create_client(ITraining, handler2training_conn)

        self.logger.debug("start inference process")
        handler2inference_conn, inference2handler_conn = mp.Pipe()
        self._inference_proc = mp.Process(
            target=run_inference,
            name="Inference",
            kwargs={"conn": inference2handler_conn, "config": config, "model": self.model, "log_queue": log_queue},
        )
        self._inference_proc.start()
        self._inference: IInference = create_client(IInference, handler2inference_conn)

        # start device setter thread that will wait for dry run processes to finish
        self.new_device_names: queue.Queue = queue.Queue()
        self.device_setter_thread = threading.Thread(
            target=add_logger(self.logger)(self._device_setter_worker), name="DeviceSetter"
        )
        self.device_setter_thread.start()
    except Exception as e:
        self.logger.exception(e)
        self.shutdown()
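
# A toy sketch of the pipe-and-process wiring used above, with the
# tiktorch-specific pieces (run_dryrun, create_client, IDryRun, ...) replaced
# by an echo worker. It shows only the handshake: the parent keeps one end of
# the duplex mp.Pipe and the child process receives the other via kwargs.
def _echo_worker_sketch(conn) -> None:
    msg = conn.recv()  # block until the parent sends a request
    conn.send(("echo", msg))  # reply over the same duplex pipe

def _pipe_process_sketch() -> None:
    parent_conn, child_conn = mp.Pipe()
    proc = mp.Process(target=_echo_worker_sketch, name="Echo", kwargs={"conn": child_conn})
    proc.start()
    parent_conn.send("ping")
    print(parent_conn.recv())  # -> ("echo", "ping")
    proc.join()
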
def __init__(self, config: dict, model: torch.nn.Module, optimizer_state: bytes = b""):
    self.logger = logging.getLogger(__name__)
    self.logger.info("started")
    self.shutdown_event = threading.Event()
    self.idle = False

    self.common_model = model
    self.model = copy.deepcopy(model)
    self.optimizer_state = optimizer_state
    self.training_settings_lock = threading.Lock()
    # self.devices = [torch.device("cpu")]
    self.devices = []
    self.base_device = "cpu"

    training_transform = Compose(
        *[
            get_transform(name, **kwargs)
            for name, kwargs in config[TRAINING].get(TRANSFORMS, {"Normalize": {"apply_to": [0]}}).items()
        ]
    )
    validation_transform = Compose(
        *[
            get_transform(name, **kwargs)
            for name, kwargs in config[VALIDATION].get(TRANSFORMS, {"Normalize": {"apply_to": [0]}}).items()
        ]
    )

    self.datasets = {
        TRAINING: DynamicDataset(transform=training_transform),
        VALIDATION: DynamicDataset(transform=validation_transform),
    }
    self.update_loader = {TRAINING: True, VALIDATION: True}
    self.loader_kwargs = {
        TRAINING: {"dataset": self.datasets[TRAINING]},
        VALIDATION: {"dataset": self.datasets[VALIDATION]},
    }

    for key, default in self.trainer_defaults.items():
        if key not in config[TRAINING]:
            config[TRAINING][key] = default

    self.config = config

    self._pause_event = threading.Event()
    self._pause_event.set()
    self.update_trainer_event = threading.Event()
    self.update_trainer_event.set()
    self.trainer = TikTrainer.build(
        model=self.model,
        break_events=[self.shutdown_event, self._pause_event, self.update_trainer_event],
        **self.create_trainer_config(),
    )
    log_dir = self.config.get(LOGGING, {}).get(DIRECTORY, "")
    if os.path.exists(log_dir):
        log_dir = os.path.join(log_dir, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
        os.makedirs(log_dir, exist_ok=True)
        self.trainer.build_logger(
            TensorboardLogger,
            log_directory=log_dir,
            log_scalars_every=(1, "iteration"),
            log_images_every=(1, "epoch"),
        )
    self.trainer.register_callback(self.end_of_training_iteration, trigger="end_of_training_iteration")
    self.trainer.register_callback(self.end_of_validation_iteration, trigger="end_of_validation_iteration")
    self.trainer._iteration_count = self.config[TRAINING].get(NUM_ITERATIONS_DONE, 0)

    if self.optimizer_state:
        optimizer = self.create_optimizer(self.optimizer_state)
        if optimizer is not None:
            self.trainer.build_optimizer(optimizer)

    self.training_thread = threading.Thread(target=add_logger(self.logger)(self._training_worker), name="Training")
    self.training_thread.start()
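
# A standalone sketch of how a TRANSFORMS config section expands into an
# inferno Compose pipeline, mirroring the constructor above. The config
# literal is a made-up example; get_transform is assumed to map a transform
# name plus kwargs to a transform instance, as it is used above.
def _build_transform_sketch() -> Compose:
    transforms_config = {"Normalize": {"apply_to": [0]}, "Cast": {"dtype": "float32"}}
    return Compose(*[get_transform(name, **kwargs) for name, kwargs in transforms_config.items()])
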