Code example #1 (votes: 0)
File: fisher.py  Project: pgsrv/fisher_vectors
def do(task, src_cfg, **kwargs):
    """Build a ``Dataset`` from ``src_cfg`` and run the requested task on it.

    The dataset is always constructed first, before dispatching on ``task``.

    Parameters
    ----------
    task : str
        One of ``'compute'``, ``'check'``, ``'merge'``, ``'remove'`` or
        ``'evaluate'``. ``'compute'`` and ``'remove'`` are currently accepted
        but do nothing.
    src_cfg
        Source configuration forwarded as the first argument to ``Dataset``.
    **kwargs
        ip_type : str, default ``'dense5.track15mbh'`` -- passed to ``Dataset``.
        suffix : str, default ``''`` -- passed to ``Dataset``.
        nr_clusters : default ``None`` -- passed to ``Dataset``.
        per_slice : bool, default ``False`` -- used only by ``'merge'``; when
            true, merges a single class via ``fisher_vectors.per_slice``.
        class_idx : used only by ``'merge'`` with ``per_slice=True``.
        For ``'evaluate'``, all keyword arguments (including the ones above)
        are forwarded to ``evaluate_given_dataset``.

    Raises
    ------
    ValueError
        If ``task`` is not a known task name. ``ValueError`` subclasses
        ``Exception``, so callers catching ``Exception`` still handle it.
    """
    ip_type = kwargs.get('ip_type', 'dense5.track15mbh')
    suffix = kwargs.get('suffix', '')
    nr_clusters = kwargs.get('nr_clusters')
    dataset = Dataset(
        src_cfg, ip_type=ip_type, suffix=suffix, nr_clusters=nr_clusters)

    per_slice = kwargs.get('per_slice', False)

    # Dispatch on the ``task`` argument.
    if task == 'compute':
        pass  # Accepted but not implemented here.
    elif task == 'check':
        from fisher_vectors import data
        data.check_given_dataset(dataset)
    elif task == 'merge':
        if per_slice:
            from fisher_vectors.per_slice import features
            class_idx = kwargs.get('class_idx')
            features.merge_class_given_dataset(dataset, class_idx)
        else:
            from fisher_vectors import data
            data.merge_given_dataset(dataset)
    elif task == 'remove':
        pass  # Accepted but not implemented here.
    elif task == 'evaluate':
        from fisher_vectors.scripts.evaluate import evaluate_given_dataset
        evaluate_given_dataset(dataset, **kwargs)
    else:
        # Raising ends the call; nothing after this line would ever run, so
        # no usage()/sys.exit() follow-up belongs here.
        raise ValueError("Task is not defined: %r." % (task,))
Code example #2 (votes: 0)
    def test_merge_class_given_dataset(self):
        """Merging class 0 must yield one label per sufficient-statistics row.

        Calls ``merge_class_given_dataset`` on ``self.dataset`` for class 0
        with fixed positive/negative sample counts. It then reloads the
        ``train_class_0.dat`` and ``labels_train_class_0.info`` files the test
        expects in ``out_folder`` and checks that their lengths agree.
        """
        # NOTE(review): /tmp is shared and world-writable; concurrent runs can
        # collide, and unpickling a file from there trusts whoever wrote it.
        out_folder = "/tmp/"
        out_fn = os.path.join(out_folder, "train_class_0.dat")
        labels_fn = os.path.join(out_folder, "labels_train_class_0.info")

        nr_pos = 30
        nr_neg = 60

        merge_class_given_dataset(
            self.dataset, 0, out_folder=out_folder, nr_positive_samples=nr_pos, nr_negative_samples=nr_neg
        )

        # Load sufficient statistics: float32 values, self.sstats_len per row.
        sstats = np.fromfile(out_fn, dtype=np.float32).reshape((-1, self.sstats_len))
        # Close the labels file deterministically instead of leaking the handle.
        with open(labels_fn, "rb") as labels_file:
            labels = pickle.load(labels_file)

        # Every sufficient-statistics row must have a matching label.
        assert sstats.shape[0] == len(labels), "Size mismatch."
Code example #3 (votes: 0)
    def test_merge_class_given_dataset(self):
        """Merging class 0 must yield one label per sufficient-statistics row.

        Calls ``merge_class_given_dataset`` on ``self.dataset`` for class 0
        with fixed positive/negative sample counts. It then reloads the
        ``train_class_0.dat`` and ``labels_train_class_0.info`` files the test
        expects in ``out_folder`` and checks that their lengths agree.
        """
        # NOTE(review): /tmp is shared and world-writable; concurrent runs can
        # collide, and unpickling a file from there trusts whoever wrote it.
        out_folder = '/tmp/'
        out_fn = os.path.join(out_folder, 'train_class_0.dat')
        labels_fn = os.path.join(out_folder, 'labels_train_class_0.info')

        nr_pos = 30
        nr_neg = 60

        merge_class_given_dataset(self.dataset,
                                  0,
                                  out_folder=out_folder,
                                  nr_positive_samples=nr_pos,
                                  nr_negative_samples=nr_neg)

        # Load sufficient statistics: float32 values, self.sstats_len per row.
        sstats = np.fromfile(out_fn, dtype=np.float32).reshape(
            (-1, self.sstats_len))
        # Close the labels file deterministically instead of leaking the handle.
        with open(labels_fn, 'rb') as labels_file:
            labels = pickle.load(labels_file)

        # Every sufficient-statistics row must have a matching label.
        assert sstats.shape[0] == len(labels), "Size mismatch."