Ejemplo n.º 1
0
def trainSVMHistory(configFilename, paramsFilename, outputHistoryFilename,
                    className):
    """Train an SVM on a prepared dataset and save the resulting history.

    :param configFilename: path to a YAML project config. Its
        ``datasetsDirectory`` entry is resolved relative to the directory
        containing this file. ``className`` is used to build the dataset file
        name, and ``groundtruth`` is the ground-truth file path.
    :param paramsFilename: path to a YAML file whose ``model`` mapping must
        have ``classifier: svm`` and a ``preprocessing`` entry. Every other key
        in ``model`` is forwarded to ``trainSVM`` as a keyword argument.
    :param outputHistoryFilename: destination passed to the history's
        ``save()``.
    :param className: if truthy, overrides the class name read from the
        ground-truth file.
    :raises Exception: if the params' classifier is not ``'svm'``.
    """
    # Read through context managers so the file handles are closed
    # deterministically. ``yaml.safe_load`` replaces the bare ``yaml.load``
    # call: without a Loader that call is unsafe on older PyYAML and raises
    # TypeError on PyYAML >= 6.
    with open(configFilename) as configFile:
        config = yaml.safe_load(configFile)
    with open(paramsFilename) as paramsFile:
        params = yaml.safe_load(paramsFile)['model']

    # ``params`` is a freshly loaded local dict, so popping the non-SVM keys
    # is safe; whatever remains is passed straight to trainSVM below.
    if params.pop('classifier') != 'svm':
        raise Exception('Can only use this script on SVM config parameters.')

    preproc = params.pop('preprocessing')

    ds = DataSet()
    ds.load(
        join(
            split(configFilename)[0],  # base dir
            config['datasetsDirectory'],  # datasets dir
            '%s-%s.db' % (config['className'], preproc)))  # dataset name

    gt = GroundTruth.fromFile(config['groundtruth'])

    if className:
        gt.className = className

    # add 'highlevel.' in front of the descriptor, this is what will appear in the final Essentia sigfile
    gt.className = 'highlevel.' + gt.className

    # do the whole training
    h = trainSVM(ds, gt, **params)

    h.save(outputHistoryFilename)
def train_svm_history(project, params, output_file_path):
    """Train an SVM for ``project`` and save the training history.

    :param project: mapping providing ``datasetsDirectory``, ``className``
        and ``groundtruth``.
    :param params: mapping whose ``model`` entry holds ``classifier`` (must be
        ``"svm"``), ``preprocessing``, ``type``, ``kernel``, ``C`` and
        ``gamma``. It is only read, never modified.
    :param output_file_path: where the history is saved. A ``unicode`` path
        is UTF-8 encoded first.
    :raises GaiaWrapperException: if the configured classifier is not SVM.
    """
    model_cfg = params["model"]
    if model_cfg.get("classifier") != "svm":
        raise GaiaWrapperException(
            "Can only use this script on SVM config parameters.")

    dataset = DataSet()
    db_path = os.path.join(
        project["datasetsDirectory"],
        "%s-%s.db" % (project["className"], model_cfg["preprocessing"]),
    )
    dataset.load(db_path)

    ground_truth = GroundTruth.fromFile(project["groundtruth"])
    ground_truth.className = "highlevel." + project["className"]

    # Only these four model settings are forwarded to the trainer.
    svm_options = dict(
        type=model_cfg["type"],
        kernel=model_cfg["kernel"],
        C=model_cfg["C"],
        gamma=model_cfg["gamma"],
    )
    history = train_svm(dataset, ground_truth, **svm_options)

    # Python 2: unicode paths are converted to UTF-8 byte strings before
    # saving (presumably save() expects a byte-string path — confirm).
    if isinstance(output_file_path, unicode):
        output_file_path = output_file_path.encode("utf-8")
    history.save(output_file_path)
def train_svm_history(project, params, output_file_path):
    """Train an SVM for ``project`` and save the training history.

    :param project: mapping providing ``datasetsDirectory``, ``className``
        and ``groundtruth``.
    :param params: mapping whose ``model`` entry holds ``classifier`` (must be
        ``"svm"``) and ``preprocessing``. Every other key in ``model`` is
        forwarded to ``train_svm`` as a keyword argument. ``params`` itself is
        not modified.
    :param output_file_path: destination passed to the history's ``save()``.
    :raises GaiaWrapperException: if the configured classifier is not SVM.
    """
    # Work on a shallow copy. The pops below used to remove keys from the
    # caller's ``params["model"]`` (even when raising), so reusing the same
    # ``params`` failed with KeyError on "classifier".
    params_model = dict(params["model"])
    if params_model.pop("classifier") != "svm":
        raise GaiaWrapperException("Can only use this script on SVM config parameters.")

    ds = DataSet()
    ds.load(os.path.join(
        project["datasetsDirectory"],
        "%s-%s.db" % (project["className"], params_model.pop("preprocessing"))
    ))

    gt = GroundTruth.fromFile(project["groundtruth"])
    gt.className = "highlevel." + project["className"]

    # Whatever remains in the copy (e.g. type/kernel/C/gamma) goes to the trainer.
    history = train_svm(ds, gt, **params_model)  # doing the whole training
    history.save(output_file_path)
Ejemplo n.º 4
0
    def loadGroundTruth(self, name=None):
        """Select a ground-truth entry from the config and load it.

        :param name: key of ``self._config['groundTruth']`` to load. When it
            is ``None``, or not one of the configured keys, the first key (in
            the mapping's iteration order) is used and a warning is printed.

        Sets ``self._groundTruthFile`` to the path returned by
        ``self.groundTruthFilePath(name)`` and ``self.groundTruth`` to the
        ``GroundTruth`` loaded from it.
        """
        # list() keeps ``gttypes[0]`` working on Python 3, where keys() is a
        # non-indexable view; on Python 2 it is just a copy of the key list.
        gttypes = list(self._config['groundTruth'].keys())

        if name is None:
            name = gttypes[0]
            if len(gttypes) > 1:
                # Single-argument print(...) prints the same on Python 2 and 3.
                print('WARNING: more than 1 GroundTruth file, selecting default "%s" (out of %s)' % (
                    name, gttypes))
        else:
            if name not in gttypes:
                print('WARNING: invalid ground truth: "%s", selecting default one instead: "%s" (out of %s)' % (
                    name, gttypes[0], gttypes))
                name = gttypes[0]

        self._groundTruthFile = self.groundTruthFilePath(name)
        self.groundTruth = GroundTruth.fromFile(self._groundTruthFile)
Ejemplo n.º 5
0
    def loadGroundTruth(self, name=None):
        """Resolve which configured ground truth to use, then load it.

        An unknown ``name`` falls back to the first configured key with a
        warning. A ``None`` name also picks the first key, warning only when
        several are configured. The resolved path is stored in
        ``self._groundTruthFile`` and the loaded object in ``self.groundTruth``.
        """
        available = self._config["groundTruth"].keys()

        if name is not None and name not in available:
            # Unknown name requested: report it, then use the first entry.
            print('WARNING: invalid ground truth: "%s", selecting default one instead: "%s" (out of %s)' % (
                name, available[0], available))
            name = available[0]
        elif name is None:
            name = available[0]
            if len(available) > 1:
                print('WARNING: more than 1 GroundTruth file, selecting default "%s" (out of %s)' % (name, available))

        self._groundTruthFile = self.groundTruthFilePath(name)
        self.groundTruth = GroundTruth.fromFile(self._groundTruthFile)
def train_svm_history(project, params, output_file_path):
    """Build an SVM from the project settings and save the training history.

    Reads the dataset ``<className>-<preprocessing>.db`` from the project's
    ``datasetsDirectory``. It trains with the model's ``type``, ``kernel``,
    ``C`` and ``gamma`` against the project's ground truth, relabelled as
    ``highlevel.<className>``. ``params`` is left untouched. A ``unicode``
    output path is UTF-8 encoded before saving.

    :raises GaiaWrapperException: when ``params["model"]`` is not an SVM config.
    """
    model = params["model"]
    classifier = model.get("classifier")
    if classifier != "svm":
        raise GaiaWrapperException("Can only use this script on SVM config parameters.")

    ds = DataSet()
    db_dir = project["datasetsDirectory"]
    db_file = "%s-%s.db" % (project["className"], model["preprocessing"])
    ds.load(os.path.join(db_dir, db_file))

    gt = GroundTruth.fromFile(project["groundtruth"])
    gt.className = "highlevel." + project["className"]

    svm_keys = ("type", "kernel", "C", "gamma")
    history = train_svm(ds, gt, **{key: model[key] for key in svm_keys})

    target = output_file_path.encode("utf-8") if isinstance(output_file_path, unicode) else output_file_path
    history.save(target)