Example #1
    def _finalize_epoch(self, epoch):
        # Aggregate the tracker's metrics for the epoch that just ended
        self._tracker.finalise(**self.tracker_options)
        if self._is_training:
            metrics = self._tracker.publish(epoch)
            # Keep the best checkpoint for each tracked metric
            self._checkpoint.save_best_models_under_current_metrics(
                self._model, metrics, self._tracker.metric_func)
            if self.wandb_log:
                Wandb.add_file(self._checkpoint.checkpoint_path)
            if self._tracker._stage == "train":
                log.info("Learning rate = %f", self._model.learning_rate)
Example #2
    @staticmethod
    def from_pretrained(model_tag,
                        download=True,
                        out_file=None,
                        weight_name="latest",
                        mock_dataset=True):
        # Resolve the download URL for the requested tag from the registry
        url = PretainedRegistry.MODELS.get(model_tag)
        if url is None:
            raise ValueError(
                "model_tag {} is not an available model. Available pre-trained models: {}"
                .format(model_tag, PretainedRegistry.available_models()))

        # Honor a caller-supplied destination; otherwise download next to the
        # other checkpoints so ModelCheckpoint can find the file
        if out_file is None:
            out_file = os.path.join(CHECKPOINT_DIR, model_tag + ".pt")
        checkpoint_dir = os.path.dirname(out_file)
        checkpoint_name = os.path.splitext(os.path.basename(out_file))[0]

        if download:
            download_file(url, out_file)

            weight_name = weight_name if weight_name is not None else "latest"

            checkpoint: ModelCheckpoint = ModelCheckpoint(
                checkpoint_dir,
                checkpoint_name,
                weight_name,
                resume=False,
            )
            if mock_dataset:
                # Rebuild the dataset description recorded in the checkpoint
                # instead of instantiating the full dataset
                dataset = checkpoint.dataset_properties.copy()
                mock_properties = PretainedRegistry.MOCK_USED_PROPERTIES.get(model_tag)
                if mock_properties is not None:
                    for k, v in mock_properties.items():
                        dataset[k] = v
            else:
                dataset = instantiate_dataset(checkpoint.data_config)

            model: BaseModel = checkpoint.create_model(dataset,
                                                       weight_name=weight_name)

            # Keep the source URL on the model and restore its data transforms
            Wandb.set_urls_to_model(model, url)
            BaseDataset.set_transform(model, checkpoint.data_config)

            return model
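A hedged usage sketch follows; the model tag below is an assumption, so consult PretainedRegistry.available_models() for the actual list:

    # Illustrative usage only; the tag is an assumed example
    model = PretainedRegistry.from_pretrained(
        "pointnet2_largemsg-s3dis-1",
        download=True,
        mock_dataset=True,   # reuse the dataset properties stored in the checkpoint
    )
    model.eval()  # assuming BaseModel behaves like a torch.nn.Module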
Example #3
    def _initialize_trainer(self):
        # Toggle the cuDNN backend according to the config
        torch.backends.cudnn.enabled = self.enable_cudnn

        if not self.has_training:
            # No dedicated training section, so the root config doubles as it
            self._cfg.training = self._cfg
        resume = bool(self._cfg.training.checkpoint_dir)

        # Get device
        if self._cfg.training.cuda > -1 and torch.cuda.is_available():
            device = "cuda"
            torch.cuda.set_device(self._cfg.training.cuda)
        else:
            device = "cpu"
        self._device = torch.device(device)
        log.info("DEVICE : {}".format(self._device))

        # Profiling
        if self.profiling:
            # torch.utils.bottleneck does not play well with worker processes,
            # so force single-process data loading
            self._cfg.training.num_workers = 0

        # Launch Wandb right away if the run is public
        if self.wandb_log:
            Wandb.launch(self._cfg, self._cfg.wandb.public)

        # Checkpoint

        self._checkpoint: ModelCheckpoint = ModelCheckpoint(
            self._cfg.training.checkpoint_dir,
            self._cfg.model_name,
            self._cfg.training.weight_name,
            run_config=self._cfg,
            resume=resume,
        )

        # Create model and dataset: resume from the checkpoint when one exists
        if not self._checkpoint.is_empty:
            self._dataset: BaseDataset = instantiate_dataset(
                self._checkpoint.data_config)
            self._model: BaseModel = self._checkpoint.create_model(
                self._dataset, weight_name=self._cfg.training.weight_name)
        else:
            # Fresh run: build the dataset and model from the config
            self._dataset: BaseDataset = instantiate_dataset(self._cfg.data)
            self._model: BaseModel = instantiate_model(
                copy.deepcopy(self._cfg), self._dataset)
            self._model.instantiate_optimizers(self._cfg, "cuda" in device)
            self._model.set_pretrained_weights()
            if not self._checkpoint.validate(self._dataset.used_properties):
                log.warning(
                    "These weights will not be usable as pre-trained weights without the matching dataset. Current properties are {}"
                    .format(self._dataset.used_properties))
        self._checkpoint.dataset_properties = self._dataset.used_properties

        log.info(self._model)

        self._model.log_optimizers()
        log.info(
            "Model size = %i",
            sum(param.numel() for param in self._model.parameters()
                if param.requires_grad))

        # Set dataloaders
        self._dataset.create_dataloaders(
            self._model,
            self._cfg.training.batch_size,
            self._cfg.training.shuffle,
            self._cfg.training.num_workers,
            self.precompute_multi_scale,
        )
        log.info(self._dataset)

        # Verify attributes in dataset
        self._model.verify_data(self._dataset.train_dataset[0])

        # Choose selection stage
        selection_stage = getattr(self._cfg, "selection_stage", "")
        self._checkpoint.selection_stage = self._dataset.resolve_saving_stage(
            selection_stage)
        self._tracker: BaseTracker = self._dataset.get_tracker(
            self.wandb_log, self.tensorboard_log)

        # Launch Wandb now for private runs (public runs launched earlier)
        if self.wandb_log:
            Wandb.launch(self._cfg, not self._cfg.wandb.public)

        # Move the model to the target device before training / evaluation
        self._model = self._model.to(self._device)
        if self.has_visualization:
            self._visualizer = Visualizer(self._cfg.visualization,
                                          self._dataset.num_batches,
                                          self._dataset.batch_size,
                                          os.getcwd())
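For completeness, a minimal sketch of how an initializer like this is typically driven; Hydra is a plausible assumption given the config access patterns above, but Trainer, the config paths, and the train() entry point are all illustrative:

    # Hypothetical entry point (illustrative only)
    import hydra

    @hydra.main(config_path="conf", config_name="config")
    def main(cfg):
        trainer = Trainer(cfg)  # assumed to call _initialize_trainer() internally
        trainer.train()         # assumed training entry point

    if __name__ == "__main__":
        main()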