Exemplo n.º 1
0
class ResnetSingleImageApplier(SingleImageInferenceBase):
    """Apply a trained ResNet classifier to one image at a time.

    Tag metadata and the tag-name -> output-index mapping are read from the
    training config. Each call to ``inference`` returns an annotation that
    carries a single image-level tag (the top-scoring one) whose value is
    that tag's score.
    """

    @property
    def classification_tags_key(self):
        """Training-config key holding the serialized classification tag metas."""
        return config_lib.classification_tags_key()

    @property
    def classification_tags_to_idx_key(self):
        """Training-config key holding the tag-name -> output-index mapping."""
        return config_lib.classification_tags_to_idx_key()

    @property
    def train_classes_key(self):
        """Config key for training classes, as defined by ``config_lib``.

        Not read anywhere in this class.
        """
        return config_lib.train_classes_key()

    @property
    def class_title_to_idx_key(self):
        """Config key from ``config_lib.class_to_idx_config_key``.

        Presumably a class-title -> index mapping; not read in this class.
        """
        return config_lib.class_to_idx_config_key()

    def _model_out_tags(self):
        """Build the output tag metas: one ANY_NUMBER tag per trained tag name.

        The tag names come from the training config. Their value type is
        replaced with ANY_NUMBER so each tag can hold a numeric score.
        """
        serialized_tags = self.train_config[self.classification_tags_key]
        trained_metas = TagMetaCollection.from_json(serialized_tags)
        numeric_metas = [
            TagMeta(meta.name, TagValueType.ANY_NUMBER)
            for meta in trained_metas
        ]
        return TagMetaCollection(numeric_metas)

    def _load_train_config(self):
        """Load the raw model config and derive the tag metadata from it.

        Sets ``classification_tags``, ``classification_tags_to_idx``,
        ``_model_out_meta`` and ``idx_to_classification_tags``, then
        determines the model input size.
        """
        # TODO: partly copy-pasted from SingleImageInferenceBase.
        self._load_raw_model_config_json()

        self.classification_tags = self._model_out_tags()
        logger.info('Read model out tags',
                    extra={'tags': self.classification_tags.to_json()})

        mapping_key = self.classification_tags_to_idx_key
        self.classification_tags_to_idx = self.train_config[mapping_key]
        logger.info('Read model internal tags mapping',
                    extra={'tags_mapping': self.classification_tags_to_idx})

        self._model_out_meta = ProjectMeta(obj_classes=ObjClassCollection(),
                                           tag_metas=self.classification_tags)

        # Reverse lookup (output index -> tag name) used by inference().
        # keys() and values() of the same mapping iterate in matching order.
        name_to_idx = self.classification_tags_to_idx
        self.idx_to_classification_tags = dict(
            zip(name_to_idx.values(), name_to_idx.keys()))
        self._determine_model_input_size()

    def _validate_model_config(self, config):
        """Check *config* with ``JsonConfigValidator.validate_inference_cfg``."""
        validator = JsonConfigValidator()
        validator.validate_inference_cfg(config)

    @staticmethod
    def get_default_config():
        """Return a fresh default inference config with GPU_DEVICE set to 0."""
        defaults = {GPU_DEVICE: 0}
        return defaults

    def _construct_and_fill_model(self):
        """Create the ResNet network and load its trained weights.

        The layer count is read from ``TaskPaths.MODEL_CONFIG_PATH``.
        Weights are loaded from ``TaskPaths.MODEL_DIR``, and the model is
        then switched to eval mode.
        """
        super()._construct_and_fill_model()
        requested_gpu = self._config[GPU_DEVICE]
        device_ids = sly.env.remap_gpu_devices([requested_gpu])
        num_layers = determine_resnet_model_configuration(
            TaskPaths.MODEL_CONFIG_PATH)
        # One output per index up to the largest mapped index, so every
        # index in the mapping is a valid output position.
        n_outputs = max(self.classification_tags_to_idx.values()) + 1
        self.model = create_model(num_layers=num_layers,
                                  n_cls=n_outputs,
                                  device_ids=device_ids)

        weights_rw = WeightsRW(TaskPaths.MODEL_DIR)
        self.model = weights_rw.load_strictly(self.model)
        self.model.eval()
        logger.info('Weights are loaded.')

    def inference(self, img, ann):
        """Classify *img* and return an annotation with the top-1 tag.

        Only ``ann.img_size`` is taken from *ann*. The result carries one
        image tag whose value is the winning score rounded to 4 decimals.
        """
        scores = infer_on_img(img, self.input_size, self.model)
        best_idx = np.argmax(scores)
        best_score = scores[best_idx]
        best_name = self.idx_to_classification_tags[best_idx]
        best_meta = self.classification_tags.get(best_name)
        rounded_score = round(float(best_score), 4)
        img_tags = TagCollection([Tag(best_meta, rounded_score)])
        return Annotation(ann.img_size, img_tags=img_tags)
Exemplo n.º 2
0
def create_model_for_inference(n_cls, device_ids, model_dir):
    """Create a model, load its trained weights and switch it to eval mode.

    Args:
        n_cls: forwarded positionally to ``create_model`` as its first argument.
        device_ids: forwarded positionally to ``create_model`` as its second
            argument.
        model_dir: directory handed to ``WeightsRW`` for ``load_strictly``.

    Returns:
        The object returned by ``WeightsRW(model_dir).load_strictly``, after
        ``.eval()`` has been called on it.
    """
    # NOTE(review): Example 1 calls create_model(num_layers=..., n_cls=...,
    # device_ids=...). Confirm that the create_model in scope here really
    # takes (n_cls, device_ids) positionally.
    untrained = create_model(n_cls, device_ids)
    loaded = WeightsRW(model_dir).load_strictly(untrained)
    # Presumably a torch module: eval() switches it to inference mode.
    loaded.eval()
    return loaded