def _construct_and_fill_model(self):
    """Instantiate the segmentation model, load its weights and move it to the GPU.

    The number of model outputs is inferred from ``self.class_title_to_idx`` as
    the largest class index + 1. Weights are read from ``sly.TaskPaths.MODEL_DIR``
    according to ``self.config[WEIGHTS_INIT_TYPE]``:

    * ``TRANSFER_LEARNING`` -- load every layer except those matching ``'_head'``;
      the head keeps the initialization it got at construction time.
    * ``CONTINUE_TRAINING`` -- load the snapshot strictly; it must match the
      model structure exactly.

    The loaded model is stored in ``self._model``.

    Raises:
        ValueError: if the configured weights init type is neither of the above.
    """
    # Progress reporting to show a progress bar in the UI.
    model_build_progress = sly.Progress('Building model:', 1)

    # Check the class name --> index mapping to infer the number of model output dimensions.
    num_classes = max(self.class_title_to_idx.values()) + 1

    # Initialize the model.
    model = PyTorchSegmentation(num_classes=num_classes)
    sly.logger.info('Model has been instantiated.')

    # Load model weights appropriate for the given training mode.
    weights_rw = WeightsRW(sly.TaskPaths.MODEL_DIR)
    weights_init_type = self.config[WEIGHTS_INIT_TYPE]
    if weights_init_type == TRANSFER_LEARNING:
        # For transfer learning, do not attempt to load the weights for the model head. The existing
        # snapshot may have been trained on a different dataset, even on a different set of classes,
        # and is in general not compatible with the current model even in terms of dimensions. The
        # head of the model will be initialized randomly.
        self._model = weights_rw.load_for_transfer_learning(
            model, ignore_matching_layers=['_head'], logger=sly.logger)
    elif weights_init_type == CONTINUE_TRAINING:
        # Continuing training from an older snapshot requires full compatibility between the two
        # models, including class index mapping. Hence the snapshot weights must exactly match the
        # structure of our model instance.
        self._model = weights_rw.load_strictly(model)
    else:
        # Fail loudly: otherwise self._model would stay unset (or stale from an earlier call)
        # and the failure would surface later as an unrelated AttributeError.
        raise ValueError(
            'Unsupported weights init type: {!r}. Expected {!r} or {!r}.'.format(
                weights_init_type, TRANSFER_LEARNING, CONTINUE_TRAINING))

    # Model weights have been loaded, move them over to the GPU.
    self._model.cuda()

    # Advance the progress bar and log a progress message.
    sly.logger.info('Weights have been loaded.', extra={WEIGHTS_INIT_TYPE: weights_init_type})
    model_build_progress.iter_done_report()
def _construct_and_fill_model(self):
    """Build the model and, when a snapshot is available, initialize its weights from it.

    The number of output classes is ``max(self.class_title_to_idx.values()) + 1``.
    If ``sly.TaskPaths.MODEL_DIR`` is empty, the model returned by ``create_model``
    is kept unchanged. Otherwise weights are loaded according to
    ``self.config['weights_init_type']``:

    * ``TRANSFER_LEARNING`` -- load every layer except those matching ``'last_conv'``.
    * ``CONTINUE_TRAINING`` -- load the snapshot strictly.

    The resulting model is stored in ``self.model``.

    Raises:
        ValueError: if weights are to be loaded but the init type is neither of the above.
    """
    # TODO: Move progress reporting to the base class.
    progress_dummy = sly.Progress('Building model:', 1)

    self.model = create_model(
        n_cls=(max(self.class_title_to_idx.values()) + 1), device_ids=self.device_ids)

    if sly.fs.dir_empty(sly.TaskPaths.MODEL_DIR):
        sly.logger.info('Weights will not be inited.')
        # @TODO: add random init (m.weight.data.normal_(0, math.sqrt(2. / n))
    else:
        wi_type = self.config['weights_init_type']
        ewit = {'weights_init_type': wi_type}
        sly.logger.info('Weights will be inited from given model.', extra=ewit)
        weights_rw = WeightsRW(sly.TaskPaths.MODEL_DIR)
        if wi_type == TRANSFER_LEARNING:
            # The classifier layer is skipped because its shape depends on the class count.
            # Use sly.logger, like the rest of this method, instead of a bare module-level
            # `logger` name that may not be bound in this module.
            self.model = weights_rw.load_for_transfer_learning(
                self.model, ignore_matching_layers=['last_conv'], logger=sly.logger)
        elif wi_type == CONTINUE_TRAINING:
            self.model = weights_rw.load_strictly(self.model)
        else:
            # Without this, an unknown value would fall through and falsely log
            # 'Weights are loaded.' while the model keeps its freshly built weights.
            raise ValueError(
                'Unsupported weights init type: {!r}. Expected {!r} or {!r}.'.format(
                    wi_type, TRANSFER_LEARNING, CONTINUE_TRAINING))
        sly.logger.info('Weights are loaded.', extra=ewit)

    # Report completion only after the model has actually been built and filled.
    progress_dummy.iter_done_report()
def _construct_and_fill_model(self):
    """Build the classification model and, when a snapshot is available, load its weights.

    The model is created with ``self.num_layers`` and one output per entry of
    ``self.classification_tags_sorted``. If ``sly.TaskPaths.MODEL_DIR`` is empty,
    the model returned by ``create_model`` is kept unchanged. Otherwise weights are
    loaded according to ``self.config['weights_init_type']``:

    * ``TRANSFER_LEARNING`` -- load every layer except those matching ``'fc'``.
    * ``CONTINUE_TRAINING`` -- load the snapshot strictly.

    The resulting model is stored in ``self.model``.

    Raises:
        ValueError: if weights are to be loaded but the init type is neither of the above.
    """
    progress_dummy = sly.Progress('Building model:', 1)

    self.model = create_model(
        self.num_layers, n_cls=len(self.classification_tags_sorted), device_ids=self.device_ids)

    if sly.fs.dir_empty(sly.TaskPaths.MODEL_DIR):
        logger.info('Weights will not be inited.')
        # @TODO: add random init (m.weight.data.normal_(0, math.sqrt(2. / n))
    else:
        wi_type = self.config['weights_init_type']
        ewit = {'weights_init_type': wi_type}
        logger.info('Weights will be inited from given model.', extra=ewit)
        weights_rw = WeightsRW(sly.TaskPaths.MODEL_DIR)
        if wi_type == TRANSFER_LEARNING:
            # Skip the final fully-connected layer: its shape depends on the number of tags.
            self.model = weights_rw.load_for_transfer_learning(
                self.model, ignore_matching_layers=['fc'], logger=logger)
        elif wi_type == CONTINUE_TRAINING:
            self.model = weights_rw.load_strictly(self.model)
        else:
            # Without this, an unknown value would fall through and falsely log
            # 'Weights are loaded.' while the model keeps its freshly built weights.
            raise ValueError(
                'Unsupported weights init type: {!r}. Expected {!r} or {!r}.'.format(
                    wi_type, TRANSFER_LEARNING, CONTINUE_TRAINING))
        logger.info('Weights are loaded.', extra=ewit)

    # Report completion only after the model has actually been built and filled.
    progress_dummy.iter_done_report()