コード例 #1
0
 def __init__(self, config_file, seeds_file, negative_seeds, similarity,
              confidence):
     """Create empty tuple/pattern containers and load the configuration.

     Args:
         config_file, seeds_file, negative_seeds, similarity, confidence:
             forwarded unchanged, in this order, to ``Config``.
     """
     self.processed_tuples = []
     # Looking up a missing key yields a fresh empty list instead of KeyError.
     self.candidate_tuples = defaultdict(list)
     self.curr_iteration = 0
     self.patterns = []
     self.config = Config(
         config_file,
         seeds_file,
         negative_seeds,
         similarity,
         confidence,
     )
コード例 #2
0
 def __init__(self, args):
     """Set the worker count, create empty tuple containers and load the config.

     Args:
         args: parsed command-line namespace. ``args.num_cores`` of ``0``
             selects ``multiprocessing.cpu_count()``; ``config_file``,
             ``positive_seeds_file``, ``negative_seeds_file``, ``similarity``
             and ``confidence`` are forwarded, in that order, to ``Config``.
     """
     self.num_cpus = (multiprocessing.cpu_count() if args.num_cores == 0
                      else args.num_cores)
     self.processed_tuples = []
     # When a key that does not exist is looked up, a new empty list is
     # returned instead of raising KeyError.
     self.candidate_tuples = defaultdict(list)
     self.config = Config(
         args.config_file,
         args.positive_seeds_file,
         args.negative_seeds_file,
         args.similarity,
         args.confidence,
     )
コード例 #3
0
ファイル: breds-parallel.py プロジェクト: yespon/BREDS
 def __init__(self, config_file, seeds_file, negative_seeds, similarity,
              confidence, num_cores):
     """Set the worker count, create empty containers and load the config.

     Args:
         config_file, seeds_file, negative_seeds, similarity, confidence:
             forwarded unchanged, in this order, to ``Config``.
         num_cores: stored as ``self.num_cpus``; ``0`` selects
             ``multiprocessing.cpu_count()`` instead.
     """
     self.num_cpus = (multiprocessing.cpu_count() if num_cores == 0
                      else num_cores)
     self.processed_tuples = []
     # Looking up a missing key yields a fresh empty list instead of KeyError.
     self.candidate_tuples = defaultdict(list)
     self.curr_iteration = 0
     self.patterns = []
     self.patterns_index = {}
     self.config = Config(
         config_file,
         seeds_file,
         negative_seeds,
         similarity,
         confidence,
     )
コード例 #4
0
ファイル: breds_inference.py プロジェクト: DAlkemade/BREDS
def gather_sizes_with_bootstrapping_patterns(cfg: Box, patterns, all_new_objects) -> DefaultDict[Tuple, list]:
    """Gather text, parse tuples and check if tuples include valid sizes.

    Generates tuples for ``all_new_objects``, matches them against
    ``patterns`` with ``config.visual`` set to
    ``cfg.parameters.visual_at_inference``, logs every candidate tuple, and
    returns the candidates kept by ``filter_tuples`` at
    ``cfg.parameters.dev_threshold``.
    """
    config = Config(cfg, VisualConfig(cfg.path.vg_objects, cfg.path.vg_objects_anchors))
    generated = generate_tuples(randomString(), config, names=all_new_objects)

    # NOTE(review): the inference-time visual flag is applied only after
    # generation; presumably generation must run with the configured value.
    config.visual = cfg.parameters.visual_at_inference

    candidates = extract_tuples(config, patterns, generated)
    kept = filter_tuples(candidates, cfg.parameters.dev_threshold)

    # Log all candidates (not only the kept ones) for inspection.
    for candidate in candidates.keys():
        logger.info(candidate.sentence)
        logger.info(f"{candidate.e1} {candidate.e2}")
        logger.info(candidate.confidence)
        logger.info("\n")

    return kept
コード例 #5
0
ファイル: visual_propagation.py プロジェクト: DAlkemade/BREDS
def main():
    """Evaluate visual-propagation size comparisons on the dev set.

    Reads ``config.yml``, builds a co-occurrence graph over the Visual Genome
    objects and, for every enabled ``BackoffSettings``:

    * predicts the larger object of each test pair and pickles the
      ``(prediction, fraction_larger, note)`` triples;
    * saves a histogram of useful-path counts;
    * logs coverage / selectivity and records a ``RelationalResult``;
    * fits a ridge regression of correctness on ``|fraction_larger - 0.5|``,
      pickles it and plots it against the binned selectivity;
    * logs the Pearson and Spearman correlations of those two quantities.

    Settings that yield no decided comparison skip the regression step.
    All ``RelationalResult`` rows are written to ``results_visual_backoff.csv``.
    """
    with open("config.yml", "r") as ymlfile:
        cfg = Box(yaml.safe_load(ymlfile))
        # cfg = Box(yaml.safe_load(ymlfile), default_box=True, default_box_attr=None)

    test_pairs, unseen_objects = comparison_dev_set(cfg)
    unseen_objects = [o.replace('_', " ") for o in unseen_objects]

    # TODO check whether the objects aren't in the bootstrapped objects
    visual_config = VisualConfig(cfg.path.vg_objects, cfg.path.vg_objects_anchors)
    config = Config(cfg, visual_config)

    # From here on use the visual config as held by ``config``.
    visual_config = config.visual_config
    objects = list(visual_config.entity_to_synsets.keys())
    logger.info(f'Objects: {objects}')
    G = build_cooccurrence_graph(objects, visual_config)

    word2vec_model = load_word2vec(cfg.parameters.word2vec_path)
    similar_words = find_similar_words(word2vec_model, unseen_objects, n_word2vec=200)

    # calc coverage and precision
    results = list()
    # Only the uncommented settings are evaluated; the others are kept as
    # ready-made alternatives.
    settings: List[BackoffSettings] = [
        # BackoffSettings(use_direct=True),
        # BackoffSettings(use_word2vec=True),
        # BackoffSettings(use_hypernyms=True),
        # BackoffSettings(use_hyponyms=True),
        # BackoffSettings(use_head_noun=True),
        # BackoffSettings(use_direct=True, use_word2vec=True),
        BackoffSettings(use_direct=True, use_word2vec=True, use_hypernyms=True),
        # BackoffSettings(use_direct=True, use_hypernyms=True),
        # BackoffSettings(use_direct=True, use_hyponyms=True),
        # BackoffSettings(use_direct=True, use_head_noun=True),
        # BackoffSettings(use_direct=True, use_hyponyms=True)
    ]
    golds = [p.larger for p in test_pairs]

    for setting in settings:
        preds = list()
        fractions_larger = list()
        notes = list()

        prop = VisualPropagation(G, config.visual_config)
        logger.info(f'\nRunning for setting {setting.print()}')
        comparer = Comparer(prop, setting, similar_words, objects)
        for test_pair in tqdm.tqdm(test_pairs):
            # TODO return confidence; use the higher one
            res_visual, fraction_larger, note = comparer.compare_visual_with_backoff(test_pair)
            fractions_larger.append(fraction_larger)
            preds.append(res_visual)
            notes.append(note)

        with open(f'visual_comparison_predictions_{setting.print()}.pkl', 'wb') as f:
            pickle.dump(list(zip(preds, fractions_larger, notes)), f)

        # Histogram of useful-path counts with symmetric-log spaced bins.
        useful_counts = comparer.useful_paths_count
        tr = SymmetricalLogTransform(base=10, linthresh=1, linscale=1)
        ss = tr.transform([0., max(useful_counts) + 1])
        bins = tr.inverted().transform(np.linspace(*ss, num=100))
        fig, ax = plt.subplots()
        plt.hist(useful_counts, bins=bins)
        plt.xlabel('Number of useful paths')
        ax.set_xscale('symlog')
        plt.savefig(f'useful_paths{setting.print()}.png')
        plt.close(fig)  # a new figure is created per setting; free this one

        useful_counts = np.array(useful_counts)
        logger.info(f'Number of objects with no useful path: {np.count_nonzero(useful_counts == 0)}')
        logger.info(f'Not recog count: {comparer.not_recognized_count}')

        logger.info(f'Total number of test cases: {len(golds)}')
        coverage, selectivity = coverage_accuracy_relational(golds, preds)
        logger.info(f'Coverage: {coverage}')
        logger.info(f'selectivity: {selectivity}')

        results.append(RelationalResult(setting.print(), selectivity, coverage))

        assert len(fractions_larger) == len(preds)
        # Keep only pairs with a usable fraction: drop None and exact ties (0.5).
        corrects_not_none = list()
        diffs_not_none = list()
        for gold, res, fraction_larger in zip(golds, preds, fractions_larger):
            if fraction_larger is not None and fraction_larger != 0.5:
                corrects_not_none.append(gold == res)
                # Distance from 0.5: how one-sided the visual evidence is.
                diffs_not_none.append(abs(fraction_larger - .5))
        # TODO do something special for when fraction_larger_centered == 0

        if not diffs_not_none:
            # Ridge.fit, min()/max() and the correlations all fail on empty input.
            logger.warning(f'No decided comparisons for setting {setting.print()}; '
                           f'skipping confidence model, plot and correlations')
            continue

        regr_linear = Ridge(alpha=1.0)
        regr_linear.fit(np.reshape(diffs_not_none, (-1, 1)), corrects_not_none)
        # NOTE: the filename does not include the setting, so with several
        # settings enabled only the last setting's model is kept.
        with open('visual_confidence_model.pkl', 'wb') as f:
            pickle.dump(regr_linear, f)

        fig, ax = plt.subplots()
        bin_means, bin_edges, _ = stats.binned_statistic(diffs_not_none, corrects_not_none, 'mean',
                                                         bins=20)
        bin_counts, _, _ = stats.binned_statistic(diffs_not_none, corrects_not_none, 'count',
                                                  bins=20)
        x = np.linspace(min(diffs_not_none), max(diffs_not_none), 500)
        X = np.reshape(x, (-1, 1))
        plt.plot(x, regr_linear.predict(X), '-', label='linear ridge regression')

        # Colour each bin's segment by its symlog-normalised sample count.
        minc = min(bin_counts)
        maxc = max(bin_counts)
        norm = colors.SymLogNorm(vmin=minc, vmax=maxc, linthresh=1)
        bin_counts_normalized = [norm(c) for c in bin_counts]
        logger.info(f'counts, norm: {list(zip(bin_counts, bin_counts_normalized))}')
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
        # pyplot.get_cmap(name, lut) is the supported equivalent.
        viridis = plt.get_cmap('viridis', 20)

        mins = bin_edges[:-1]
        maxs = bin_edges[1:]
        mask = ~np.isnan(bin_means)  # empty bins have a NaN mean
        plt.hlines(np.extract(mask, bin_means), np.extract(mask, mins), np.extract(mask, maxs),
                   colors=viridis(np.extract(mask, bin_counts_normalized)), lw=5,
                   label='binned statistic of data')
        sm = plt.cm.ScalarMappable(cmap=viridis, norm=norm)
        ticks = [10**1.5, 10**1.75, 10**2, 10**2.5]
        # ``sm`` is not attached to any Axes, so the host Axes must be given
        # explicitly; the implicit current-Axes fallback is deprecated/removed.
        colorbar = plt.colorbar(sm, ax=ax, ticks=ticks)
        colorbar.ax.set_yticklabels(['10^1.5', '10^1.75', '10^2', '10^2.5'])
        colorbar.set_label('bin count')
        plt.ylim(-0.05, 1.05)
        plt.legend()
        plt.xlabel('Absolute fraction_larger')
        plt.ylabel('Selectivity')
        ax.set_xscale('linear')
        plt.savefig('fraction_larger_selectivity_linear.png')
        plt.show()
        plt.close(fig)

        correlation, _ = pearsonr(diffs_not_none, corrects_not_none)
        logger.info(f'Pearsons correlation: {correlation}')

        correlation_spearman, _ = spearmanr(np.array(diffs_not_none), b=np.array(corrects_not_none))
        logger.info(f'Spearman correlation: {correlation_spearman}')

    results_df = pd.DataFrame(results)
    results_df.to_csv('results_visual_backoff.csv')
コード例 #6
0
def main():
    """Predict point sizes for unseen dev-set objects from visual comparisons.

    For each object in ``cfg.path.dev`` that appears in the Visual Genome
    object list, compare it against every numeric seed loaded from
    ``cfg.path.final_seeds_cache`` with ``VisualPropagation.compare_pair``.
    A ``fraction_larger`` above 0.5 adds the seed to the lower bounds and
    below 0.5 to the upper bounds (exactly 0.5 and ``None`` add nothing).
    Three estimators turn the bound sizes into a size. The three prediction
    dicts are keyed by the object name with underscores shown as spaces,
    with ``None`` for objects absent from VG, and are pickled to disk.
    """
    with open("config.yml", "r") as ymlfile:
        cfg = Box(yaml.safe_load(ymlfile))
        # cfg = Box(yaml.safe_load(ymlfile), default_box=True, default_box_attr=None)

    # TODO check whether the objects aren't in the bootstrapped objects
    visual_config = VisualConfig(cfg.path.vg_objects,
                                 cfg.path.vg_objects_anchors)
    config = Config(cfg, visual_config)

    # Named ``dev_df`` rather than ``input`` to avoid shadowing the builtin.
    dev_df: DataFrame = pd.read_csv(cfg.path.dev)
    dev_df = dev_df.astype({'object': str})
    unseen_objects = list(dev_df['object'])
    logger.info(f'Unseen objects: {unseen_objects}')

    visual_config = config.visual_config
    objects = list(visual_config.entity_to_synsets.keys())
    objects_set = set(objects)  # O(1) membership for the per-object check
    logger.info(f'Objects: {objects}')
    G = build_cooccurrence_graph(objects, visual_config)

    with open(cfg.path.final_seeds_cache, encoding='utf-8') as f:
        numeric_seeds = json.load(f)

    numeric_seeds = {key.strip().replace(' ', '_'): value
                     for key, value in numeric_seeds.items()}
    # There is a 'rhine' in VG, which was included in VG as the river. It is
    # removed manually since it shows up in a lot of results. pop() rather
    # than del so a seed cache without 'rhine' does not raise KeyError.
    numeric_seeds.pop('rhine', None)

    point_predictions = dict()
    point_predictions_evenly = dict()
    point_predictions_svm = dict()
    prop = VisualPropagation(G, config.visual_config)
    for unseen_object in unseen_objects:
        unseen_object = unseen_object.replace(' ', '_')
        display_name = unseen_object.replace('_', ' ')
        logger.info(f'Processing {unseen_object}')
        if unseen_object not in objects_set:
            logger.info(f'{unseen_object} not in visuals')
            point_predictions[display_name] = None
            point_predictions_evenly[display_name] = None
            point_predictions_svm[display_name] = None
            continue
        none_count = 0
        lower_bounds = set()
        upper_bounds = set()
        for numeric_seed in tqdm.tqdm(numeric_seeds):
            pair = Pair(unseen_object, numeric_seed)
            if pair.both_in_list(objects):
                fraction_larger, _ = prop.compare_pair(pair)
                if fraction_larger is None:
                    none_count += 1
                    continue
                if fraction_larger < .5:
                    upper_bounds.add(numeric_seed)
                if fraction_larger > .5:
                    lower_bounds.add(numeric_seed)
                logger.debug(
                    f'{pair.e1} {pair.e2} fraction larger: {fraction_larger}')
            else:
                # NOTE(review): this f-string formats the whole ``objects``
                # list on every miss, even when DEBUG is disabled.
                logger.debug(
                    f'{pair.e1} or {pair.e2} not in VG. Objects: {objects}')

        lower_bounds_sizes = fill_sizes_list(lower_bounds, numeric_seeds)
        upper_bounds_sizes = fill_sizes_list(upper_bounds, numeric_seeds)

        # size = predict_size_with_bounds(lower_bounds_sizes, upper_bounds_sizes)
        size = iterativily_find_size(lower_bounds_sizes, upper_bounds_sizes)
        size_evenly = iterativily_find_size_evenly(lower_bounds_sizes,
                                                   upper_bounds_sizes)
        size_svm = predict_size_with_bounds(lower_bounds_sizes,
                                            upper_bounds_sizes)

        point_predictions[display_name] = size
        point_predictions_evenly[display_name] = size_evenly
        point_predictions_svm[display_name] = size_svm
        logger.info(f'\nObject: {unseen_object}')
        logger.info(f'Size: {size}')
        logger.info(f'Size evenly: {size_evenly}')
        logger.info(f'Size svm: {size_svm}')
        logger.info(
            f"None count: {none_count} out of {len(numeric_seeds)}")
        logger.info(
            f"Lower bounds (n={len(lower_bounds)}): mean: {np.mean(lower_bounds_sizes)} median: {np.median(lower_bounds_sizes)}\n\t{lower_bounds}\n\t{lower_bounds_sizes}"
        )
        logger.info(
            f"Upper bounds (n={len(upper_bounds)}): mean: {np.mean(upper_bounds_sizes)} median: {np.median(upper_bounds_sizes)}\n\t{upper_bounds}\n\t{upper_bounds_sizes}"
        )

    with open('point_predictions_visual_ranges.pkl', 'wb') as f:
        pickle.dump(point_predictions, f)

    with open('point_predictions_visual_ranges_evenly.pkl', 'wb') as f:
        pickle.dump(point_predictions_evenly, f)

    with open('point_predictions_visual_ranges_svm.pkl', 'wb') as f:
        pickle.dump(point_predictions_svm, f)