Ejemplo n.º 1
0
def create_selected(selected_model_name,
                    dataset,
                    target,
                    features,
                    validation_set='auto',
                    verbose=True):
    """
    Train the model named by `selected_model_name` and return it wrapped in
    the matching model-specific Python class.

    Parameters
    ----------
    selected_model_name : str
        Internal name of the model to train. Supported names are the keys of
        the wrapper table below (e.g. 'boosted_trees_regression',
        'classifier_svm').

    dataset
        Training data, forwarded unchanged to ``create``.

    target
        Target specification, forwarded unchanged to ``create``.

    features
        Feature specification, forwarded to ``create`` as ``features``.

    validation_set : optional
        Forwarded to ``create`` as ``validation_set``. Defaults to 'auto'.

    verbose : bool, optional
        Forwarded to ``create`` as ``verbose``.

    Returns
    -------
    out
        An instance of the wrapper class registered for
        `selected_model_name`, constructed from the trained model's
        ``__proxy__``.

    Raises
    ------
    ToolkitError
        If `selected_model_name` is not a supported name. The check happens
        after ``create`` has run.
    """
    # Model name -> (turicreate submodule, wrapper class). Names are kept as
    # strings so only the selected wrapper is resolved, as in an if/elif chain.
    wrappers = {
        'boosted_trees_regression':
            ('boosted_trees_regression', 'BoostedTreesRegression'),
        'random_forest_regression':
            ('random_forest_regression', 'RandomForestRegression'),
        # DecisionTreeRegression belongs to the decision_tree_regression
        # module, not decision_tree_classifier.
        'decision_tree_regression':
            ('decision_tree_regression', 'DecisionTreeRegression'),
        'regression_linear_regression':
            ('linear_regression', 'LinearRegression'),
        'boosted_trees_classifier':
            ('boosted_trees_classifier', 'BoostedTreesClassifier'),
        'random_forest_classifier':
            ('random_forest_classifier', 'RandomForestClassifier'),
        'decision_tree_classifier':
            ('decision_tree_classifier', 'DecisionTreeClassifier'),
        'classifier_logistic_regression':
            ('logistic_classifier', 'LogisticClassifier'),
        'classifier_svm':
            ('svm_classifier', 'SVMClassifier'),
    }

    # Create the model
    model = create(dataset,
                   target,
                   selected_model_name,
                   features=features,
                   validation_set=validation_set,
                   verbose=verbose)

    # Return the model wrapped in its specific class
    wrapper_location = wrappers.get(selected_model_name)
    if wrapper_location is None:
        raise ToolkitError("Internal error: Incorrect model returned.")

    module_name, class_name = wrapper_location
    wrapper_class = getattr(getattr(_turicreate, module_name), class_name)
    return wrapper_class(model.__proxy__)
Ejemplo n.º 2
0
def _validate_row_label(label, column_type_map):
    """
    Validate a row label column.

    Parameters
    ----------
    label : str
        Name of the row label column.

    column_type_map : dict[str, type]
        Dictionary mapping the name of each column in an SFrame to the type of
        the values in the column.
    """
    if not isinstance(label, str):
        raise TypeError("The row label column name must be a string.")

    if not label in column_type_map.keys():
        raise ToolkitError("Row label column not found in the dataset.")

    if not column_type_map[label] in (str, int):
        raise TypeError("Row labels must be integers or strings.")
Ejemplo n.º 3
0
    def summary(self, output=None):
        """
        Print a summary of the model. The summary includes a description of
        training data, options, hyper-parameters, and statistics measured
        during model creation.

        Parameters
        ----------
        output : str, None
            The type of summary to return.

            - None or 'stdout' : print directly to stdout.

            - 'str' : string of summary

            - 'dict' : a dict with 'sections' and 'section_titles' ordered
              lists. The entries in the 'sections' list are tuples of the form
              ('label', 'value').

        Returns
        -------
        out : None, str or dict
            - None when the summary was printed to stdout. If building or
              printing the representation fails, the class name is returned
              as a string instead.

            - The result of ``self.__repr__()`` for 'str'.

            - The result of ``_toolkit_serialize_summary_struct`` for 'dict'.

        Raises
        ------
        ToolkitError
            If `output` is not one of the supported values.

        Examples
        --------
        >>> m.summary()
        """
        if output is None or output == "stdout":
            try:
                print(self.__repr__())
            except Exception:
                # Best-effort fallback when the representation cannot be
                # built or printed. Only ordinary errors are caught, so
                # KeyboardInterrupt and SystemExit still propagate.
                return self.__class__.__name__
        elif output == "str":
            return self.__repr__()
        elif output == "dict":
            return _toolkit_serialize_summary_struct(
                self, *self._get_summary_struct())
        else:
            raise ToolkitError("Unsupported argument " + str(output) +
                               ' for "summary" parameter.')
def shuffle_sframe(sf, random_seed=None, temp_shuffle_col="shuffle_col"):
    """
    Return a copy of an SFrame with its rows in random order.

    The copy gets a temporary column of random integers, is sorted by that
    column, and then has the column removed again.

    Parameters
    ----------
    sf: SFrame
        A Non empty SFrame.
    random_seed: int, optional
        Seed passed to ``SArray.random_integers`` when generating the sort
        keys.
    temp_shuffle_col: str, optional
        Name of the temporary sort-key column. Change it only if `sf` already
        has a column with this name.

    Returns
    -------
    SFrame
        A randomly shuffled SFrame.

    Raises
    ------
    ToolkitError
        If `sf` already contains a column named `temp_shuffle_col`.

    Examples
    --------
        >>> url = 'https://static.turi.com/datasets/xgboost/mushroom.csv'
        >>> sf = tc.SFrame.read_csv(url)
        >>> shuffle_sframe(sf)
    """
    existing_columns = sf.column_names()
    if temp_shuffle_col in existing_columns:
        message_template = (
            'The SFrame already contains column named {0}. '
            'Please enter set another value to temp_shuffle_col')
        raise ToolkitError(message_template.format(temp_shuffle_col))

    # Work on a copy so the caller's SFrame is left without the extra column.
    result = sf.copy()
    row_count = sf.num_rows()
    result[temp_shuffle_col] = tc.SArray.random_integers(row_count,
                                                         random_seed)

    ordered = result.sort(temp_shuffle_col)
    return ordered.remove_column(temp_shuffle_col)
Ejemplo n.º 5
0
def load_model(location):
    """
    Load any Turi Create model that was previously saved.

    This function assumes the model (can be any model) was previously saved in
    Turi Create model format with model.save(filename).

    Parameters
    ----------
    location : string
        Location of the model to load. Can be a local path or a remote URL.
        Because models are saved as directories, there is no file extension.

    Returns
    -------
    out
        The loaded model object.

    Raises
    ------
    IOError
        If no 'dir_archive.ini' is found under `location` (the check is
        skipped for http/https locations).

    ToolkitError
        If the archive version is negative or newer than this code supports,
        or if a legacy-format model is loaded under Python 3.

    Examples
    ----------
    >>> model.save('my_model_file')
    >>> loaded_model = tc.load_model('my_model_file')
    """

    # A saved model is a dir_archive, recognised by its 'dir_archive.ini'.
    # http(s) locations skip the check and are loaded directly as a
    # dir_archive, because
    # 1) exists() does not work with http protocol, and
    # 2) GLUnpickler does not support http
    protocol = file_util.get_protocol(location)
    if protocol == '':
        local_root = file_util.expand_full_path(location)
        dir_archive_exists = file_util.exists(
            os.path.join(local_root, 'dir_archive.ini'))
    elif protocol in ('http', 'https'):
        dir_archive_exists = True
    else:
        import posixpath
        dir_archive_exists = file_util.exists(
            posixpath.join(location, 'dir_archive.ini'))

    if not dir_archive_exists:
        raise IOError("Directory %s does not exist" % location)

    _internal_url = _make_internal_url(location)
    saved_state = glconnect.get_unity().load_model(_internal_url)

    # The archive version could be both bytes/unicode
    version_key = u'archive_version'
    if version_key in saved_state:
        archive_version = saved_state[version_key]
    else:
        archive_version = saved_state[version_key.encode()]

    if archive_version < 0:
        raise ToolkitError("File does not appear to be a Turi Create model.")

    if archive_version > 1:
        raise ToolkitError(
            "Unable to load model.\n\n"
            "This model looks to have been saved with a future version of Turi Create.\n"
            "Please upgrade Turi Create before attempting to load this model file."
        )

    if archive_version == 1:
        model_class = MODEL_NAME_MAP[saved_state['model_name']]
        if 'model' in saved_state:
            # this is a native model
            return model_class(saved_state['model'])

        # this is a CustomModel; its version travels inside the side data
        # and is stripped out before the data is handed to the loader
        side_data = saved_state['side_data']
        side_data_version = side_data['model_version']
        del side_data['model_version']
        return model_class._load_version(side_data, side_data_version)

    # very legacy model format. Attempt pickle loading
    import sys
    sys.stderr.write(
        "This model was saved in a legacy model format. Compatibility cannot be guaranteed in future versions.\n"
    )
    if _six.PY3:
        raise ToolkitError(
            "Unable to load legacy model in Python 3.\n\n"
            "To migrate a model, try loading it using Turi Create 4.0 or\n"
            "later in Python 2 and then re-save it. The re-saved model should\n"
            "work in Python 3.")

    if 'graphlab' not in sys.modules:
        sys.modules['graphlab'] = sys.modules['turicreate']
        # backward compatibility. Otherwise old pickles will not load
        sys.modules["turicreate_util"] = sys.modules['turicreate.util']
        sys.modules["graphlab_util"] = sys.modules['turicreate.util']

        # More backwards compatibility with the turicreate namespace code:
        # alias every loaded turicreate module under a graphlab name.
        for module_name, module in list(sys.modules.items()):
            if 'turicreate' not in module_name:
                continue
            alias = module_name.replace('turicreate', 'graphlab')
            sys.modules[alias] = module

    # legacy loader
    # NOTE(review): pickle.loads can execute arbitrary code contained in the
    # archive; only load legacy models from trusted sources.
    import pickle
    legacy_wrapper = pickle.loads(saved_state[b'model_wrapper'])
    return legacy_wrapper(saved_state[b'model_base'])
Ejemplo n.º 6
0
def load_model(location):
    """
    Load any Turi Create model that was previously saved.

    This function assumes the model (can be any model) was previously saved in
    Turi Create model format with model.save(filename).

    Parameters
    ----------
    location : string
        Location of the model to load. Can be a local path or a remote URL.
        Because models are saved as directories, there is no file extension.

    Returns
    -------
    out
        The loaded model object.

    Raises
    ------
    IOError
        If no 'dir_archive.ini' is found under `location` (the check is
        skipped for http, https and s3 locations).

    ToolkitError
        If the archive version is negative or newer than this code supports,
        if the saved model name is not registered, or if a legacy-format
        model is loaded under Python 3.

    Examples
    ----------
    >>> model.save('my_model_file')
    >>> loaded_model = tc.load_model('my_model_file')
    """

    # A saved model is a dir_archive, recognised by its 'dir_archive.ini'.
    # Remote http/https/s3 locations skip the check and are loaded directly
    # as a dir_archive. The original comment gives the reason for http:
    # 1) exists() does not work with http protocol, and
    # 2) GLUnpickler does not support http
    protocol = file_util.get_protocol(location)
    if protocol == "":
        local_root = file_util.expand_full_path(location)
        dir_archive_exists = file_util.exists(
            os.path.join(local_root, "dir_archive.ini"))
    elif protocol in ("http", "https", "s3"):
        dir_archive_exists = True
    else:
        import posixpath

        dir_archive_exists = file_util.exists(
            posixpath.join(location, "dir_archive.ini"))

    if not dir_archive_exists:
        raise IOError("Directory %s does not exist" % location)

    _internal_url = _make_internal_url(location)
    saved_state = glconnect.get_unity().load_model(_internal_url)
    saved_state = _wrap_function_return(saved_state)

    # The archive version could be both bytes/unicode
    version_key = u"archive_version"
    if version_key in saved_state:
        archive_version = saved_state[version_key]
    else:
        archive_version = saved_state[version_key.encode()]

    if archive_version < 0:
        raise ToolkitError("File does not appear to be a Turi Create model.")

    if archive_version > 1:
        raise ToolkitError(
            "Unable to load model.\n\n"
            "This model looks to have been saved with a future version of Turi Create.\n"
            "Please upgrade Turi Create before attempting to load this model file."
        )

    if archive_version == 1:
        # Models for which libtctensorflow is imported before construction,
        # and whose custom-model data is imported through the _extensions
        # class of the same name.
        tensorflow_backed = (
            "activity_classifier",
            "object_detector",
            "style_transfer",
            "drawing_classifier",
        )

        name = saved_state["model_name"]
        if name not in MODEL_NAME_MAP:
            if hasattr(_extensions, name):
                return saved_state["model"]
            raise ToolkitError(
                "Unable to load model of name '%s'; model name not registered."
                % name)

        model_class = MODEL_NAME_MAP[name]

        if "model" in saved_state:
            if name in tensorflow_backed:
                import turicreate.toolkits.libtctensorflow
            # this is a native model
            return model_class(saved_state["model"])

        # this is a CustomModel; its version travels inside the side data
        # and is stripped out before the data is handed to a loader
        side_data = saved_state["side_data"]
        side_data_version = side_data["model_version"]
        del side_data["model_version"]

        if name in tensorflow_backed:
            import turicreate.toolkits.libtctensorflow

            native_model = getattr(_extensions, name)()
            native_model.import_from_custom_model(side_data,
                                                  side_data_version)
            return model_class(native_model)

        if name == "one_shot_object_detector":
            import turicreate.toolkits.libtctensorflow

            detector_class = MODEL_NAME_MAP["object_detector"]
            detector_data = side_data["detector"]
            if "detector_model" in detector_data:
                # The detector is already stored as a native model.
                side_data["detector"] = detector_class(
                    detector_data["detector_model"])
            else:
                # The detector is stored as custom-model data and is
                # imported with its own recorded version.
                native_detector = _extensions.object_detector()
                native_detector.import_from_custom_model(
                    detector_data, side_data["_detector_version"])
                side_data["detector"] = detector_class(native_detector)
            return model_class(side_data)

        return model_class._load_version(side_data, side_data_version)

    # very legacy model format. Attempt pickle loading
    import sys

    sys.stderr.write(
        "This model was saved in a legacy model format. Compatibility cannot be guaranteed in future versions.\n"
    )
    if _six.PY3:
        raise ToolkitError(
            "Unable to load legacy model in Python 3.\n\n"
            "To migrate a model, try loading it using Turi Create 4.0 or\n"
            "later in Python 2 and then re-save it. The re-saved model should\n"
            "work in Python 3.")

    if "graphlab" not in sys.modules:
        sys.modules["graphlab"] = sys.modules["turicreate"]
        # backward compatibility. Otherwise old pickles will not load
        sys.modules["turicreate_util"] = sys.modules["turicreate.util"]
        sys.modules["graphlab_util"] = sys.modules["turicreate.util"]

        # More backwards compatibility with the turicreate namespace code:
        # alias every loaded turicreate module under a graphlab name.
        for module_name, module in list(sys.modules.items()):
            if "turicreate" not in module_name:
                continue
            alias = module_name.replace("turicreate", "graphlab")
            sys.modules[alias] = module

    # legacy loader
    # NOTE(review): pickle.loads can execute arbitrary code contained in the
    # archive; only load legacy models from trusted sources.
    import pickle

    legacy_wrapper = pickle.loads(saved_state[b"model_wrapper"])
    return legacy_wrapper(saved_state[b"model_base"])