Ejemplo n.º 1
0
    def __init__(self, config: dict, model: torch.nn.Module) -> None:
        """Resolve the configured loss criterion and start the dry-run worker thread.

        :param config: configuration dict; the criterion name is read from
            ``config[TRAINING][LOSS_CRITERION_CONFIG]["method"]``
        :param model: model, stored as ``self.model``
        :raises ValueError: if the criterion exists neither in ``torch.nn`` nor in ``criteria``
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info("started")
        self.config = config
        self.model = model

        criterion_name = config[TRAINING][LOSS_CRITERION_CONFIG]["method"]
        # torch.nn takes precedence; the inferno extension criteria are only a fallback.
        found = getattr(torch.nn, criterion_name, None)
        if found is None:
            found = getattr(criteria, criterion_name, None)
            if found is None:
                raise ValueError(f"Criterion {criterion_name} not found.")
        self.criterion_class = found

        self.shutdown_event = threading.Event()

        # Shape bookkeeping; unset until something assigns it later.
        self.training_shape = None
        self.valid_shapes = None
        self.shrinkage: Optional[Point] = None

        self.dry_run_queue = queue.Queue()
        dry_run_target = add_logger(self.logger)(self._dry_run_worker)
        self.dry_run_thread = threading.Thread(target=dry_run_target, name="DryRun")
        self.dry_run_thread.start()
Ejemplo n.º 2
0
    def add_devices(self, devices: Set[torch.device]) -> None:
        """Register new devices and start one forward-worker thread per device.

        Every device is validated before any worker is started, so a rejected
        call leaves the device bookkeeping untouched instead of half-updated.

        :param devices: devices to add; none may already be registered and each
            may appear only once
        :raises AssertionError: if a device is already registered or repeated
            (checks are ``assert`` statements and are skipped under ``python -O``)
        """
        self.logger.debug("add devices %s", devices)
        # Materialize once so a one-shot iterable survives the two passes below.
        new_devices = list(devices)

        # Validate the whole batch first: checking inside the start loop could leave
        # workers running for earlier devices that never make it into self.devices.
        assert len(set(new_devices)) == len(new_devices), "duplicate devices"
        for d in new_devices:
            assert d not in self.devices, f"device {d} already added"
            assert d not in self.shutdown_worker_events
            assert d not in self.forward_worker_threads

        for d in new_devices:
            self.shutdown_worker_events[d] = threading.Event()
            worker = threading.Thread(
                target=add_logger(self.logger)(self._forward_worker), name=f"ForwardWorker({d})", kwargs={"device": d}
            )
            self.forward_worker_threads[d] = worker
            worker.start()
            # Record each device as soon as its worker is running so self.devices
            # never lags behind the started threads.
            self.devices.add(d)
Ejemplo n.º 3
0
    def __init__(self, config: dict, model: torch.nn.Module) -> None:
        """Set up inference bookkeeping and start the device-setter thread.

        :param config: configuration dict; ``INFERENCE_BATCH_SIZE`` is optional
        :param model: model, stored as ``self.training_model``
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info("started")
        self.config = config
        self.training_model = model

        self.shutdown_event = threading.Event()

        # None when the config does not specify an inference batch size.
        self.batch_size: Optional[int] = config.get(INFERENCE_BATCH_SIZE, None)
        self.forward_queue = queue.Queue()
        # Per-device worker state keyed by device. No device is assigned here; they
        # are presumably added later (e.g. via the device setter) — TODO confirm.
        self.shutdown_worker_events = {}
        self.forward_worker_threads = {}
        self.devices = set()

        self.device_change_queue = queue.Queue()
        self.device_setter_thread = threading.Thread(
            target=add_logger(self.logger)(self._device_setter_worker), name="DeviceSetter"
        )
        self.device_setter_thread.start()
Ejemplo n.º 4
0
    def __init__(
        self,
        config: dict,
        model_file: bytes,
        model_state: bytes,
        optimizer_state: bytes,
        log_queue: Optional[mp.Queue] = None,
    ) -> None:
        """
        Load the user model from source bytes, then start the dry-run, training and
        inference worker processes and the device-setter thread.

        :param config: configuration dict; must contain MODEL_CLASS_NAME
        :param model_file: bytes of file describing the neural network model
        :param model_state: binarized model state dict
        :param optimizer_state: binarized optimizer state dict
        :param log_queue: optional queue passed to every child process
            (presumably for log forwarding — TODO confirm)
        :raises ValueError: if a required config key is missing
        """
        # NOTE(review): `assert` is stripped under `python -O`, so this check then disappears.
        assert model_file
        for required in [MODEL_CLASS_NAME]:
            if required not in config:
                raise ValueError(f"{required} missing in config")

        self.config = config

        self.shutdown_event = threading.Event()

        self.logger = logging.getLogger(__name__)
        self.logger.info("started")
        self.valid_shapes: Optional[List[Point]] = None
        self.shrinkage: Optional[Point] = None
        self.idle_devices: List[torch.device] = []
        self.training_devices: List[torch.device] = []
        self.inference_devices: List[torch.device] = []

        # Write the user's model source as `usermodel.py` into a fresh temp dir.
        # NOTE(review): this executes caller-supplied Python code on import; only
        # accept model_file from trusted sources.
        self.tempdir = tempfile.mkdtemp()
        user_module_name = "usermodel"
        with open(os.path.join(self.tempdir, user_module_name + ".py"),
                  "wb") as f:
            f.write(model_file)

        # Make the temp dir importable and import the module from it.
        # NOTE(review): import_module returns the cached sys.modules entry if a
        # "usermodel" module was already imported in this process, so a second
        # handler built in the same process would reuse the first model's code.
        # The temp dir and this sys.path entry are not removed in this block.
        sys.path.insert(0, self.tempdir)
        user_module = importlib.import_module(user_module_name)

        # Instantiate the configured class, passing MODEL_INIT_KWARGS if present.
        self.model: torch.nn.Module = getattr(
            user_module, self.config[MODEL_CLASS_NAME])(
                **self.config.get(MODEL_INIT_KWARGS, {}))
        self.logger.debug("created user model")

        if model_state:
            self.logger.debug("load model state")
            try:
                # NOTE(review): torch.load unpickles its input; model_state must be trusted.
                self.model.load_state_dict(
                    torch.load(io.BytesIO(model_state), map_location="cpu"))
            except Exception as e:
                # Best effort: the failure is logged and construction continues.
                self.logger.exception(e)
            else:
                self.logger.info("restored model state")

        try:
            # Each worker runs in its own process. It receives one end of a Pipe, and
            # the handler keeps a client built by create_client on the other end.
            # NOTE(review): self.model is passed to the child processes; under a
            # spawn start method it must be picklable, and its class importable as
            # `usermodel` (presumably why sys.path is extended above) — confirm.
            self.logger.debug("start dryrun process")
            handler2dryrun_conn, dryrun2handler_conn = mp.Pipe()
            self._dry_run_proc = mp.Process(
                name="DryRun",
                target=run_dryrun,
                kwargs={
                    "conn": dryrun2handler_conn,
                    "config": config,
                    "model": self.model,
                    "log_queue": log_queue
                },
            )
            self._dry_run_proc.start()
            self._dry_run: IDryRun = create_client(IDryRun,
                                                   handler2dryrun_conn)

            self.logger.debug("start training process")
            handler2training_conn, training2handler_conn = mp.Pipe()
            self._training_proc = mp.Process(
                target=run_training,
                name="Training",
                kwargs={
                    "conn": training2handler_conn,
                    "config": config,
                    "model": self.model,
                    "optimizer_state": optimizer_state,
                    "log_queue": log_queue,
                },
            )
            self._training_proc.start()
            self._training: ITraining = create_client(ITraining,
                                                      handler2training_conn)

            self.logger.debug("start inference process")
            handler2inference_conn, inference2handler_conn = mp.Pipe()
            self._inference_proc = mp.Process(
                target=run_inference,
                name="Inference",
                kwargs={
                    "conn": inference2handler_conn,
                    "config": config,
                    "model": self.model,
                    "log_queue": log_queue
                },
            )
            self._inference_proc.start()
            self._inference: IInference = create_client(
                IInference, handler2inference_conn)

            # start device setter thread that will wait for dry run processes to finish
            self.new_device_names: queue.Queue = queue.Queue()
            self.device_setter_thread = threading.Thread(target=add_logger(
                self.logger)(self._device_setter_worker),
                                                         name="DeviceSetter")
            self.device_setter_thread.start()
        except Exception as e:
            # NOTE(review): the exception is logged and shutdown() is called, but it
            # is not re-raised, so the constructor returns a partially initialized
            # handler and the caller gets no error signal.
            self.logger.exception(e)
            self.shutdown()
Ejemplo n.º 5
0
    def __init__(self, config: dict, model: torch.nn.Module, optimizer_state: bytes = b""):
        """Set up datasets, the trainer, optional Tensorboard logging and the training thread.

        :param config: configuration dict; keys of ``self.trainer_defaults`` missing
            from ``config[TRAINING]`` are filled in place
        :param model: shared model; a deep copy of it is what this worker trains
        :param optimizer_state: serialized optimizer state; when non-empty it is
            passed to ``create_optimizer`` and the result (if not None) is installed
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info("started")
        self.shutdown_event = threading.Event()
        self.idle = False

        # Train on a private copy; the caller's model stays reachable as common_model.
        self.common_model = model
        self.model = copy.deepcopy(model)
        self.optimizer_state = optimizer_state
        self.training_settings_lock = threading.Lock()
        self.devices = []
        self.base_device = "cpu"

        def build_transform(section):
            # Without configured transforms, normalize the first input only.
            specs = config[section].get(TRANSFORMS, {"Normalize": {"apply_to": [0]}})
            return Compose(*(get_transform(name, **kwargs) for name, kwargs in specs.items()))

        train_tf = build_transform(TRAINING)
        valid_tf = build_transform(VALIDATION)

        self.datasets = {
            TRAINING: DynamicDataset(transform=train_tf),
            VALIDATION: DynamicDataset(transform=valid_tf),
        }
        self.update_loader = dict.fromkeys((TRAINING, VALIDATION), True)
        self.loader_kwargs = {split: {"dataset": self.datasets[split]} for split in (TRAINING, VALIDATION)}

        training_section = config[TRAINING]
        for option, fallback in self.trainer_defaults.items():
            if option not in training_section:
                training_section[option] = fallback

        self.config = config

        # Break events start out set.
        self._pause_event = threading.Event()
        self._pause_event.set()
        self.update_trainer_event = threading.Event()
        self.update_trainer_event.set()

        stop_events = [self.shutdown_event, self._pause_event, self.update_trainer_event]
        trainer_options = self.create_trainer_config()
        self.trainer = TikTrainer.build(model=self.model, break_events=stop_events, **trainer_options)

        # Tensorboard logging only when the configured directory already exists;
        # each run gets its own timestamped subdirectory.
        log_root = self.config.get(LOGGING, {}).get(DIRECTORY, "")
        if os.path.exists(log_root):
            run_dir = os.path.join(log_root, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            os.makedirs(run_dir, exist_ok=True)
            self.trainer.build_logger(
                TensorboardLogger,
                log_directory=run_dir,
                log_scalars_every=(1, "iteration"),
                log_images_every=(1, "epoch"),
            )

        for callback, trigger in (
            (self.end_of_training_iteration, "end_of_training_iteration"),
            (self.end_of_validation_iteration, "end_of_validation_iteration"),
        ):
            self.trainer.register_callback(callback, trigger=trigger)
        self.trainer._iteration_count = self.config[TRAINING].get(NUM_ITERATIONS_DONE, 0)

        if self.optimizer_state:
            restored_optimizer = self.create_optimizer(self.optimizer_state)
            if restored_optimizer is not None:
                self.trainer.build_optimizer(restored_optimizer)

        training_target = add_logger(self.logger)(self._training_worker)
        self.training_thread = threading.Thread(target=training_target, name="Training")
        self.training_thread.start()