示例#1
0
    def _construct_and_fill_model(self):
        """Instantiate the segmentation network, load its weights and move it to the GPU.

        Sets ``self._model``. The number of output channels is taken as the
        largest value in ``self.class_title_to_idx`` plus one. Weights are read
        from ``sly.TaskPaths.MODEL_DIR`` in the way selected by
        ``self.config[WEIGHTS_INIT_TYPE]``.

        Raises:
            ValueError: if the configured weights init type is neither
                ``TRANSFER_LEARNING`` nor ``CONTINUE_TRAINING``. Previously this
                case left ``self._model`` unassigned and failed later with an
                unrelated error.
        """
        # Progress reporting to show a progress bar in the UI.
        model_build_progress = sly.Progress('Building model:', 1)

        # Check the class name --> index mapping to infer the number of model output dimensions.
        num_classes = max(self.class_title_to_idx.values()) + 1

        # Initialize the model.
        model = PyTorchSegmentation(num_classes=num_classes)
        sly.logger.info('Model has been instantiated.')

        # Load model weights appropriate for the given training mode.
        weights_rw = WeightsRW(sly.TaskPaths.MODEL_DIR)
        weights_init_type = self.config[WEIGHTS_INIT_TYPE]
        if weights_init_type == TRANSFER_LEARNING:
            # For transfer learning, do not attempt to load the weights for the model head. The existing snapshot may
            # have been trained on a different dataset, even on a different set of classes, and is in general not
            # compatible with the current model even in terms of dimensions. The head of the model will be initialized
            # randomly.
            self._model = weights_rw.load_for_transfer_learning(
                model, ignore_matching_layers=['_head'], logger=sly.logger)
        elif weights_init_type == CONTINUE_TRAINING:
            # Continuing training from an older snapshot requires full compatibility between the two models, including
            # class index mapping. Hence the snapshot weights must exactly match the structure of our model instance.
            self._model = weights_rw.load_strictly(model)
        else:
            # Fail loudly here instead of letting the missing/stale self._model surface later as an
            # unrelated AttributeError (or silently training an old model).
            raise ValueError(
                'Unsupported weights init type: {!r}. Expected {!r} or {!r}.'.format(
                    weights_init_type, TRANSFER_LEARNING, CONTINUE_TRAINING))

        # Model weights have been loaded, move them over to the GPU.
        self._model.cuda()

        # Advance the progress bar and log a progress message.
        sly.logger.info('Weights have been loaded.',
                        extra={WEIGHTS_INIT_TYPE: weights_init_type})
        model_build_progress.iter_done_report()
示例#2
0
    def _construct_and_fill_model(self):
        """Create the network and, if a source model is present, fill it with stored weights.

        Sets ``self.model``. The class count passed to ``create_model`` is the
        largest value in ``self.class_title_to_idx`` plus one. If
        ``sly.TaskPaths.MODEL_DIR`` is empty, the weights produced by
        ``create_model`` are kept. Otherwise weights are read from that
        directory according to ``self.config['weights_init_type']``.
        """
        # TODO: Move it progress to base class
        # A single-step progress report so the UI shows a "Building model" stage.
        progress_dummy = sly.Progress('Building model:', 1)
        progress_dummy.iter_done_report()
        self.model = create_model(
            n_cls=(max(self.class_title_to_idx.values()) + 1),
            device_ids=self.device_ids)

        if sly.fs.dir_empty(sly.TaskPaths.MODEL_DIR):
            sly.logger.info('Weights will not be inited.')
            # @TODO: add random init (m.weight.data.normal_(0, math.sqrt(2. / n))
        else:
            wi_type = self.config['weights_init_type']
            ewit = {'weights_init_type': wi_type}
            sly.logger.info('Weights will be inited from given model.',
                            extra=ewit)

            weights_rw = WeightsRW(sly.TaskPaths.MODEL_DIR)
            if wi_type == TRANSFER_LEARNING:
                # 'last_conv' is excluded from loading; presumably it is the
                # class-count-dependent output layer (TODO confirm).
                # Use sly.logger like the rest of this method instead of a bare
                # module-level `logger` name that this block does not guarantee exists.
                self.model = weights_rw.load_for_transfer_learning(
                    self.model,
                    ignore_matching_layers=['last_conv'],
                    logger=sly.logger)
            elif wi_type == CONTINUE_TRAINING:
                self.model = weights_rw.load_strictly(self.model)

            sly.logger.info('Weights are loaded.', extra=ewit)
示例#3
0
    def _construct_and_fill_model(self):
        """Build the classification network and optionally seed it with stored weights.

        A one-step progress report is emitted first so the UI shows a
        "Building model" stage. ``self.model`` is then created with one output
        per entry of ``self.classification_tags_sorted``. If
        ``sly.TaskPaths.MODEL_DIR`` is empty, the freshly created weights are
        kept. Otherwise weights are read from that directory in the way
        selected by ``self.config['weights_init_type']``.
        """
        build_progress = sly.Progress('Building model:', 1)
        build_progress.iter_done_report()

        class_count = len(self.classification_tags_sorted)
        self.model = create_model(self.num_layers, n_cls=class_count, device_ids=self.device_ids)

        model_dir = sly.TaskPaths.MODEL_DIR
        if sly.fs.dir_empty(model_dir):
            # Nothing to restore from: keep whatever create_model produced.
            # @TODO: add random init (m.weight.data.normal_(0, math.sqrt(2. / n))
            logger.info('Weights will not be inited.')
            return

        init_mode = self.config['weights_init_type']
        log_extra = {'weights_init_type': init_mode}
        logger.info('Weights will be inited from given model.', extra=log_extra)

        weights_rw = WeightsRW(model_dir)
        if init_mode == TRANSFER_LEARNING:
            # 'fc' is excluded from loading; presumably the classifier head whose
            # size depends on the class count (TODO confirm).
            self.model = weights_rw.load_for_transfer_learning(
                self.model, ignore_matching_layers=['fc'], logger=logger)
        elif init_mode == CONTINUE_TRAINING:
            self.model = weights_rw.load_strictly(self.model)

        logger.info('Weights are loaded.', extra=log_extra)