示例#1
0
    def _load_version(cls, state, version):
        """
        Rebuild an ImageClassifier from its saved state.

        Parameters
        ----------
        state : dict
            Saved model state. The 'classifier', 'model' and
            'input_image_shape' entries are read; 'classifier', 'classes',
            'feature_extractor' and 'input_image_shape' are overwritten in
            place ('model' is rewritten only to fix a legacy misspelling).

        version : int
            Version of the saved model, passed to the version check together
            with ``cls._PYTHON_IMAGE_CLASSIFIER_VERSION``.

        Returns
        -------
        out : ImageClassifier
            Classifier constructed from the updated ``state``.

        Raises
        ------
        ToolkitError
            If the model is VisionFeaturePrint_Scene and ``_mac_ver()``
            reports a version lower than (10, 14).
        """
        _tkutl._model_version_check(
            version, cls._PYTHON_IMAGE_CLASSIFIER_VERSION)
        from turicreate.toolkits.classifier.logistic_classifier import LogisticClassifier

        # Wrap the stored classifier and expose its class labels.
        logistic = LogisticClassifier(state['classifier'])
        state['classifier'] = logistic
        state['classes'] = logistic.classes

        # Older saves spelled the model name wrong; repair it in the state.
        model_name = state['model']
        if model_name == "VisionFeaturePrint_Screen":
            model_name = state['model'] = "VisionFeaturePrint_Scene"

        # Refuse to load the Vision-based model where it cannot run.
        uses_vision = model_name == "VisionFeaturePrint_Scene"
        if uses_vision and _mac_ver() < (10, 14):
            raise ToolkitError(
                "Can not load model on this operating system. This model uses VisionFeaturePrint_Scene, "
                "which is only supported on macOS 10.14 and higher.")

        # Re-create the feature extractor for the named pre-trained model.
        extractor = _image_feature_extractor._create_feature_extractor(
            model_name)
        state['feature_extractor'] = extractor

        # Coerce every stored dimension to int.
        state['input_image_shape'] = tuple(
            int(dim) for dim in state['input_image_shape'])
        return ImageClassifier(state)
    def _load_version(cls, state, version):
        """
        Rebuild an ActivityClassifier from its saved state.

        Parameters
        ----------
        state : dict
            Saved model state. Reads 'features', '_target_id_map',
            'prediction_window', '_predictions_in_chunk', 'num_sessions',
            'batch_size' and the serialized parameters stored under
            '_pred_model' ('arg_params' / 'aux_params'). '_pred_model' is
            replaced in place by the re-created prediction module.

        version : int
            Version of the saved model, passed to the version check together
            with ``cls._PYTHON_ACTIVITY_CLASSIFIER_VERSION``.

        Returns
        -------
        out : ActivityClassifier
            Classifier constructed from the updated ``state``.
        """
        _tkutl._model_version_check(version, cls._PYTHON_ACTIVITY_CLASSIFIER_VERSION)

        from ._model_architecture import _define_model

        preds_in_chunk = state['_predictions_in_chunk']
        context = _mxnet_utils.get_mxnet_context(max_devices=state['num_sessions'])
        # Only the prediction module is needed; the first returned value is
        # discarded.
        _, _pred_model = _define_model(state['features'], state['_target_id_map'],
                                       state['prediction_window'],
                                       preds_in_chunk, context)

        # The module is bound with for_training=False and no label shapes, so
        # only the input data shape has to be described here.
        batch_size = state['batch_size']
        win = state['prediction_window'] * preds_in_chunk
        num_features = len(state['features'])
        data_shapes = [('data', (batch_size, win, num_features))]

        _pred_model.bind(data_shapes=data_shapes, label_shapes=None,
                         for_training=False)

        # Restore the saved weights into the freshly bound module.
        arg_params = _mxnet_utils.params_from_dict(state['_pred_model']['arg_params'])
        aux_params = _mxnet_utils.params_from_dict(state['_pred_model']['aux_params'])
        _pred_model.init_params(arg_params=arg_params, aux_params=aux_params)
        state['_pred_model'] = _pred_model

        return ActivityClassifier(state)
示例#3
0
    def _load_version(cls, state, version):
        """
        Rebuild an ImageSimilarityModel from its saved state.

        Parameters
        ----------
        state : dict
            Saved model state. The 'similarity_model', 'model' and
            'input_image_shape' entries are read; 'similarity_model',
            'feature_extractor' and 'input_image_shape' are overwritten in
            place ('model' is rewritten only to fix a legacy misspelling).

        version : int
            Version number maintained by the class writer, passed to the
            version check with ``cls._PYTHON_IMAGE_SIMILARITY_VERSION``.

        Returns
        -------
        out : ImageSimilarityModel
            Model constructed from the updated ``state``.

        Raises
        ------
        _ToolkitError
            If the model is VisionFeaturePrint_Scene and ``_mac_ver()``
            reports a version lower than (10, 14).
        """
        _tkutl._model_version_check(
            version, cls._PYTHON_IMAGE_SIMILARITY_VERSION)
        from turicreate.toolkits.nearest_neighbors import NearestNeighborsModel

        # Wrap the stored nearest-neighbors model.
        neighbors = NearestNeighborsModel(state["similarity_model"])
        state["similarity_model"] = neighbors

        # Older saves spelled the model name wrong; repair it in the state.
        model_name = state["model"]
        if model_name == "VisionFeaturePrint_Screen":
            model_name = state["model"] = "VisionFeaturePrint_Scene"

        # Refuse to load the Vision-based model where it cannot run.
        uses_vision = model_name == "VisionFeaturePrint_Scene"
        if uses_vision and _mac_ver() < (10, 14):
            raise _ToolkitError(
                "Can not load model on this operating system. This model uses VisionFeaturePrint_Scene, "
                "which is only supported on macOS 10.14 and higher.")

        extractor = _image_feature_extractor._create_feature_extractor(
            model_name)
        state["feature_extractor"] = extractor

        # Coerce every stored dimension to int.
        state["input_image_shape"] = tuple(
            int(dim) for dim in state["input_image_shape"])
        return ImageSimilarityModel(state)
示例#4
0
    def _load_version(cls, state, version):
        """
        Rebuild an ImageSimilarityModel from its saved state.

        Parameters
        ----------
        state : dict
            Saved model state. The 'similarity_model', 'model' and
            'input_image_shape' entries are read; 'similarity_model',
            'feature_extractor' and 'input_image_shape' are overwritten in
            place.

        version : int
            Version number maintained by the class writer, passed to the
            version check with ``cls._PYTHON_IMAGE_SIMILARITY_VERSION``.

        Returns
        -------
        out : ImageSimilarityModel
            Model constructed from the updated ``state``.
        """
        _tkutl._model_version_check(
            version, cls._PYTHON_IMAGE_SIMILARITY_VERSION)
        from turicreate.toolkits.nearest_neighbors import NearestNeighborsModel

        # Wrap the stored nearest-neighbors model.
        neighbors = NearestNeighborsModel(state['similarity_model'])
        state['similarity_model'] = neighbors

        # Look up the pre-trained model by name, instantiate it, and build
        # the feature extractor around it.
        pretrained = _pre_trained_models.MODELS[state['model']]()
        state['feature_extractor'] = _image_feature_extractor.MXFeatureExtractor(
            pretrained)

        # Coerce every stored dimension to int.
        state['input_image_shape'] = tuple(
            int(dim) for dim in state['input_image_shape'])
        return ImageSimilarityModel(state)
示例#5
0
 def _load_version(cls, state, version):
     """
     Rebuild a DrawingClassifier from its saved state.

     The version is checked against the literal 1. A network sized by the
     number of entries in ``state['classes']`` is created, the parameters
     stored under ``state['_model']`` are loaded into it, and
     ``state['_model']`` is replaced by that network before the classifier
     is constructed.
     """
     _tkutl._model_version_check(version, 1)
     from ._model_architecture import Model as _Model

     class_count = len(state['classes'])
     network = _Model(num_classes=class_count, prefix='drawing_')
     context = _mxnet_utils.get_mxnet_context(max_devices=state['batch_size'])

     # Load the saved weights into the new network's parameters.
     _mxnet_utils.load_net_params_from_state(
         network.collect_params(), state['_model'], ctx=context)
     state['_model'] = network
     return DrawingClassifier(state)
示例#6
0
    def _load_version(cls, state, version):
        """
        Rebuild a StyleTransfer model from its saved state.

        A transformer network is created from ``state['num_styles']`` and
        ``state['batch_size']``, the parameters stored under
        ``state['_model']`` are loaded into it, and ``state['_model']`` is
        replaced by that network. ``state['input_image_shape']`` is
        rewritten as a tuple of ints.
        """
        from ._model import Transformer as _Transformer
        _tkutl._model_version_check(
            version, cls._PYTHON_STYLE_TRANSFER_VERSION)

        transformer = _Transformer(state['num_styles'], state['batch_size'])
        context = _mxnet_utils.get_mxnet_context(
            max_devices=state['batch_size'])

        # Load the saved weights into the new network's parameters.
        _mxnet_utils.load_net_params_from_state(
            transformer.collect_params(), state['_model'], ctx=context)
        state['_model'] = transformer

        # Coerce every stored dimension to int.
        state['input_image_shape'] = tuple(
            int(dim) for dim in state['input_image_shape'])
        return StyleTransfer(state)
示例#7
0
    def _load_version(cls, state, version):
        """
        Rebuild an ImageClassifier from its saved state.

        The stored classifier is wrapped in a LogisticClassifier and its
        class labels are copied to ``state['classes']``. The pre-trained
        model named by ``state['model']`` is instantiated to build the
        feature extractor, and ``state['input_image_shape']`` is rewritten
        as a tuple of ints.
        """
        _tkutl._model_version_check(
            version, cls._PYTHON_IMAGE_CLASSIFIER_VERSION)
        from turicreate.toolkits.classifier.logistic_classifier import LogisticClassifier

        logistic = LogisticClassifier(state['classifier'])
        state['classifier'] = logistic
        state['classes'] = logistic.classes

        # Look up the pre-trained model by name, instantiate it, and build
        # the feature extractor around it.
        pretrained = _pre_trained_models.MODELS[state['model']]()
        state['feature_extractor'] = _image_feature_extractor.MXFeatureExtractor(
            pretrained)

        # Coerce every stored dimension to int.
        state['input_image_shape'] = tuple(
            int(dim) for dim in state['input_image_shape'])
        return ImageClassifier(state)
示例#8
0
    def _load_version(cls, state, version):
        """
        Rebuild an ObjectDetector from its saved state.

        The network's output size is ``(num_classes + 5) * num_anchors``,
        taken from ``state['num_classes']`` and the length of
        ``state['anchors']``. The parameters stored under ``state['_model']``
        are loaded into the new network, which then replaces
        ``state['_model']``. ``state['input_image_shape']`` and
        ``state['_grid_shape']`` are rewritten as tuples of ints.
        """
        _tkutl._model_version_check(
            version, cls._PYTHON_OBJECT_DETECTOR_VERSION)
        from ._model import tiny_darknet as _tiny_darknet

        anchor_count = len(state['anchors'])
        class_count = state['num_classes']
        darknet = _tiny_darknet(
            output_size=(class_count + 5) * anchor_count)
        context = _mxnet_utils.get_mxnet_context(
            max_devices=state['batch_size'])

        # Load the saved weights into the new network's parameters.
        _mxnet_utils.load_net_params_from_state(
            darknet.collect_params(), state['_model'], ctx=context)
        state['_model'] = darknet

        # Coerce every stored dimension to int.
        state['input_image_shape'] = tuple(
            int(dim) for dim in state['input_image_shape'])
        state['_grid_shape'] = tuple(
            int(dim) for dim in state['_grid_shape'])
        return ObjectDetector(state)
示例#9
0
 def _load_version(cls, state, version):
     """
     Rebuild a DrawingClassifier from its saved state.

     A network sized by the number of entries in ``state['classes']`` is
     created, the parameters stored under ``state['_model']`` are loaded
     into it, and ``state['_model']`` is replaced by that network. If the
     first class label is a float, every label is cast to int.
     """
     _tkutl._model_version_check(
         version, cls._PYTHON_DRAWING_CLASSIFIER_VERSION)
     from ._model_architecture import Model as _Model
     from .._mxnet import _mxnet_utils

     class_count = len(state['classes'])
     network = _Model(num_classes=class_count, prefix='drawing_')
     context = _mxnet_utils.get_mxnet_context()

     # Load the saved weights into the new network's parameters.
     _mxnet_utils.load_net_params_from_state(
         network.collect_params(), state['_model'], ctx=context)
     state['_model'] = network

     # A model trained on integer classes gets its classes back as floats
     # after a save/load round trip, so cast them back to int here.
     class_labels = state['classes']
     if len(class_labels) > 0 and isinstance(class_labels[0], float):
         state['classes'] = [int(label) for label in class_labels]
     return DrawingClassifier(state)