def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions, canopy2tree, trees):
    """Cluster one batch of records and append per-canopy predictions.

    With more than one point, an ``Agglom`` HAC tree is built over the encoded
    records and cut at ``model.aux['threshold']``. The tree is appended to
    ``trees``, and canopies first seen in this batch map to its index in
    ``canopy2tree``. With zero or one point, no tree is built: every point gets
    cluster 0 and newly seen canopies map to ``None``.

    :param all_pids: ids of the records in this batch.
    :param all_lbls: labels for the records (not used here).
    :param all_records: records passed to ``encoding_model.encode``.
    :param all_canopies: canopy of each record, indexed like ``all_pids``.
    :param model: similarity model; ``model.aux['threshold']`` is the cut.
    :param encoding_model: object whose ``encode`` turns records into features.
    :param canopy2predictions: dict canopy -> [pids, predicted cluster labels];
        mutated in place.
    :param canopy2tree: dict canopy -> index into ``trees`` (or None); only set
        for canopies not already present in ``canopy2predictions``.
    :param trees: list that receives the built tree, if any.
    :return: ``canopy2predictions``.
    """
    features = encoding_model.encode(all_records)
    n_points = len(all_pids)

    if n_points > 1:
        hac = Agglom(model, features, num_points=n_points)
        hac.build_dendrogram_hac()
        assignments = hac.flat_clustering(model.aux['threshold'])
        assigned_tree = len(trees)
        trees.append(hac)
    else:
        # Nothing to merge: the (at most one) point is its own cluster 0.
        assignments = [0]
        assigned_tree = None

    for idx, pid in enumerate(all_pids):
        canopy = all_canopies[idx]
        if canopy not in canopy2predictions:
            canopy2predictions[canopy] = [[], []]
            canopy2tree[canopy] = assigned_tree
        pid_bucket, label_bucket = canopy2predictions[canopy]
        pid_bucket.append(pid)
        label_bucket.append('%s-%s' % (canopy, assignments[idx]))
    return canopy2predictions
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions, canopy2tree, trees,
                 pids_list):
    """Cluster one canopy batch with HAC and store the tree and predictions.

    :param all_pids: Ids of the records we are running on.
    :param all_lbls: Labels (or blank labels if they are not available); not
        used here.
    :param all_records: The records (i.e. features - python strings of the
        inventor name, assignee etc).
    :param all_canopies: The canopy of each record, indexed like ``all_pids``.
    :param model: Linear scoring / similarity model which takes in two feature
        vectors and gives a similarity; ``model.aux['threshold']`` is the cut
        used for the flat clustering.
    :param encoding_model: Its ``encode(all_records)`` output is the input to
        ``model``.
    :param canopy2predictions: dict canopy -> [pids, predicted cluster labels];
        mutated in place and returned.
    :param canopy2tree: dict canopy -> index of this batch's tree in ``trees``;
        overwritten for every canopy in the batch.
    :param trees: list to which the built tree is appended.
    :param pids_list: list to which ``all_pids`` is appended, parallel to
        ``trees``.
    :return: ``canopy2predictions``.
    :raises ValueError: if the batch has fewer than two points.
    """
    # Reject the batch before spending time encoding it.
    if len(all_pids) <= 1:
        raise ValueError('Must have non-singleton canopies')

    # extracting features
    features = encoding_model.encode(all_records)

    # running clustering
    grinch = Agglom(model, features, num_points=len(all_pids))
    grinch.build_dendrogram_hac()
    fc = grinch.flat_clustering(model.aux['threshold'])

    # store the tree built for this canopy (state kept for the incremental
    # setting) together with the pids it was built over
    tree_id = len(trees)
    trees.append(grinch)
    pids_list.append(all_pids)

    for i, pid in enumerate(all_pids):
        canopy = all_canopies[i]
        # record mapping from canopy to the tree id
        canopy2tree[canopy] = tree_id
        if canopy not in canopy2predictions:
            canopy2predictions[canopy] = [[], []]
        # save predictions (used in the non incremental setting)
        canopy2predictions[canopy][0].append(pid)
        canopy2predictions[canopy][1].append('%s-%s' % (canopy, fc[i]))
    return canopy2predictions
# Example #3
# 0
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions, canopy2tree, trees,
                 pids_list, canopy_list):
    """Cluster one batch with HAC, log linkage statistics, and store results.

    :param all_pids: ids of the records in this batch.
    :param all_lbls: labels for the records (not used here).
    :param all_records: records passed to ``encoding_model.encode``.
    :param all_canopies: canopy of each record, indexed like ``all_pids``.
    :param model: similarity model; ``model.aux['threshold']`` is the cut.
    :param encoding_model: object whose ``encode`` turns records into features.
    :param canopy2predictions: dict canopy -> [pids, predicted cluster labels];
        mutated in place.
    :param canopy2tree: dict canopy -> index into ``trees``; only set for
        canopies not already present in ``canopy2predictions``.
    :param trees: list to which the built tree is appended.
    :param pids_list: list to which ``all_pids`` is appended.
    :param canopy_list: list to which ``all_canopies`` is appended.
    :return: ``canopy2predictions``.
    """
    features = encoding_model.encode(all_records)
    grinch = Agglom(model,
                    features,
                    num_points=len(all_pids),
                    min_allowable_sim=0)
    grinch.build_dendrogram_hac()
    threshold = model.aux['threshold']
    fc = grinch.flat_clustering(threshold)

    # Fetch the merge linkages once instead of once per statistic.
    linkages = grinch.all_thresholds()
    if len(linkages) > 0:
        logging.info(
            'run_on_batch - threshold %s, linkages: min %s, max %s, avg %s, std %s',
            threshold, np.min(linkages), np.max(linkages), np.mean(linkages),
            np.std(linkages))
    else:
        # np.min/np.max raise ValueError on an empty array. That is presumably
        # what all_thresholds() returns when nothing was merged (e.g. a
        # single point), so log without the statistics instead of crashing.
        logging.info('run_on_batch - threshold %s, no linkages (%s points)',
                     threshold, len(all_pids))

    tree_id = len(trees)
    trees.append(grinch)
    pids_list.append(all_pids)
    canopy_list.append(all_canopies)
    for i, pid in enumerate(all_pids):
        canopy = all_canopies[i]
        if canopy not in canopy2predictions:
            canopy2predictions[canopy] = [[], []]
            canopy2tree[canopy] = tree_id
        canopy2predictions[canopy][0].append(pid)
        canopy2predictions[canopy][1].append('%s-%s' % (canopy, fc[i]))
    return canopy2predictions
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions):
    """Cluster a batch of records with HAC and append per-canopy predictions.

    The records are encoded, clustered with ``Agglom`` and cut at
    ``model.aux['threshold']``. Each point's pid and its ``'<canopy>-<cluster>'``
    label are appended to ``canopy2predictions[canopy]``, which is
    ``[pids, labels]``. ``all_lbls`` is not used.

    :return: ``canopy2predictions`` (mutated in place).
    """
    n_points = len(all_pids)
    hac = Agglom(model, encoding_model.encode(all_records), num_points=n_points)
    hac.build_dendrogram_hac()
    cluster_ids = hac.flat_clustering(model.aux['threshold'])
    for idx in range(n_points):
        canopy = all_canopies[idx]
        pid_bucket, label_bucket = canopy2predictions.setdefault(canopy, [[], []])
        pid_bucket.append(all_pids[idx])
        label_bucket.append('%s-%s' % (canopy, cluster_ids[idx]))
    return canopy2predictions
# Example #5
# 0
    def dev_eval(self, datasets):
        """Evaluate on dev canopies, sweep thresholds and keep the best one.

        Builds an ``Agglom`` HAC tree per non-empty dev dataset. For every
        threshold from ``self.get_thresholds`` (sorted ascending), it scores
        the flat clustering with ``eval_micro_pw_f1`` on points whose gold
        label is not ``'-1'``. Predictions at the best threshold are written to
        ``<self.outdir>/dev/dev_<self.global_step>.tsv``. If the best micro
        pairwise F1 beats ``self.best_f1``, the method stores the threshold in
        ``self.model.aux['threshold']``, updates ``self.best_f1`` and calls
        ``self.save('best')``.

        :param datasets: dict name -> sequence whose first four items are
            (pids, labels, records, features).
        :return: (best threshold, metrics dict at that threshold).
        """
        logging.info('dev eval using %s datasets', len(datasets))
        trees = []
        gold_clustering = []
        dataset_names = []
        for idx, (dataset_name, dataset) in enumerate(datasets.items()):
            pids, lbls, features = dataset[0], dataset[1], dataset[3]
            logging.info('Running on dev dataset %s of %s | %s with %s points', idx, len(datasets), dataset_name,
                         len(pids))
            if len(pids) > 0:
                grinch = Agglom(self.grinch.model, features, num_points=len(pids))
                grinch.build_dendrogram_hac()
                trees.append(grinch)
                gold_clustering.extend(lbls)
                dataset_names.append(dataset_name)
        eval_ids = [i for i, lbl in enumerate(gold_clustering) if lbl != '-1']
        # np.squeeze turns a single threshold into a 0-d array, which np.sort
        # (axis=-1) rejects; atleast_1d keeps it a 1-d array.
        thresholds = np.sort(np.atleast_1d(np.squeeze(self.get_thresholds(trees, self.num_thresholds))))
        scores_per_threshold = []
        for thres in thresholds:
            pred_clustering = []
            for idx, t in enumerate(trees):
                fc = t.flat_clustering(thres)
                pred_clustering.extend(['%s-%s' % (dataset_names[idx], x) for x in fc])
            metrics = eval_micro_pw_f1([pred_clustering[x] for x in eval_ids], [gold_clustering[x] for x in eval_ids])
            scores_per_threshold.append(metrics)
            logging.info('[dev] threshold %s | %s', thres, "|".join(['%s=%s' % (k, v) for k, v in metrics.items()]))

        arg_best_f1 = max(range(len(scores_per_threshold)),
                          key=lambda x: scores_per_threshold[x]['micro_pw_f1'])
        best_threshold = thresholds[arg_best_f1]

        os.makedirs(os.path.join(self.outdir, 'dev'), exist_ok=True)
        # `with` guarantees the file is closed even if writing fails.
        with open(os.path.join(self.outdir, 'dev', 'dev_%s.tsv' % self.global_step), 'w') as dev_out_f:
            for name, t in zip(dataset_names, trees):
                dataset = datasets[name]
                lbls, records = dataset[1], dataset[2]
                fc = t.flat_clustering(best_threshold)
                for j, record in enumerate(records):
                    dev_out_f.write("%s\n" % record.pretty_tsv('%s-%s' % (name, fc[j]), lbls[j]))

        metrics = scores_per_threshold[arg_best_f1]
        logging.info('[dev] best threshold %s | %s', best_threshold,
                     "|".join(['%s=%s' % (k, v) for k, v in metrics.items()]))
        if metrics['micro_pw_f1'] > self.best_f1:
            logging.info('new best f1 %s > %s', metrics['micro_pw_f1'], self.best_f1)
            self.model.aux['threshold'] = best_threshold
            self.best_f1 = metrics['micro_pw_f1']
            self.save('best')
        return best_threshold, metrics
# Example #6
# 0
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions):
    """Cluster a batch with HAC (``min_allowable_sim=0``) and record predictions.

    Each point's pid and its ``'<canopy>-<cluster>'`` label, from a cut at
    ``model.aux['threshold']``, are appended to
    ``canopy2predictions[canopy] == [pids, labels]``. ``all_lbls`` is not used.

    :return: ``canopy2predictions`` (mutated in place).
    """
    encoded = encoding_model.encode(all_records)
    clusterer = Agglom(model,
                       encoded,
                       num_points=len(all_pids),
                       min_allowable_sim=0)
    clusterer.build_dendrogram_hac()
    assignments = clusterer.flat_clustering(model.aux['threshold'])
    for i, pid in enumerate(all_pids):
        canopy = all_canopies[i]
        if canopy not in canopy2predictions:
            canopy2predictions[canopy] = [[], []]
        entry = canopy2predictions[canopy]
        entry[0].append(pid)
        entry[1].append('%s-%s' % (canopy, assignments[i]))
    return canopy2predictions
def main(argv):
    """Load (or build and cache) train/dev canopy data, then train the model.

    Configuration comes from ``FLAGS``. The train and dev splits are cached as
    pickles under ``FLAGS.cache_dir``, and so are the parsed rawinventor
    mentions. Dev canopies are kept only when their size is within
    ``[FLAGS.min_dev_size, FLAGS.max_dev_size]``.

    :param argv: command-line arguments (only logged).
    """
    logging.info('Running clustering with argv %s', str(argv))
    wandb.init(project="%s-%s" % (FLAGS.exp_name, FLAGS.dataset_name))
    wandb.config.update(flags.FLAGS)

    outdir = os.path.join(FLAGS.outprefix, wandb.env.get_project(),
                          os.environ.get(wandb.env.SWEEP_ID, 'solo'),
                          wandb.env.get_run())

    logging.info('outdir %s', outdir)

    os.makedirs(outdir, exist_ok=True)

    os.makedirs(FLAGS.cache_dir, exist_ok=True)

    need_to_load_train = False
    train_cache_file = os.path.join(
        FLAGS.cache_dir,
        'train-%s.pkl' % os.path.basename(FLAGS.training_data))
    dev_cache_file = os.path.join(
        FLAGS.cache_dir, 'dev-%s.pkl' % os.path.basename(FLAGS.dev_data))
    if os.path.exists(train_cache_file):
        logging.info('Using cached training data! %s', train_cache_file)
        with open(train_cache_file, 'rb') as fin:
            training_collection = pickle.load(fin)
        all_training_pids, all_training_labels, all_training_records, train_datasets = training_collection
    else:
        need_to_load_train = True

    need_to_load_dev = False
    if os.path.exists(dev_cache_file):
        logging.info('Using cached dev data! %s', dev_cache_file)
        with open(dev_cache_file, 'rb') as fin:
            dev_collection = pickle.load(fin)
        # Re-apply the size filter in case the size flags changed since the
        # cache was written.
        dev_datasets = {
            canopy: [pids, lbls, records]
            for canopy, (pids, lbls, records) in dev_collection[0].items()
            if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size
        }
    else:
        need_to_load_dev = True

    if need_to_load_train or need_to_load_dev:
        logging.info('need_to_load_train %s', need_to_load_train)
        logging.info('need_to_load_dev %s', need_to_load_dev)
        logging.info('Loading canopy dictionary...')
        with open(FLAGS.canopy2record_dict, 'rb') as fin:
            canopy2record = pickle.load(fin)

        with open(FLAGS.record2canopy_dict, 'rb') as fin:
            record2canopy = pickle.load(fin)

        logging.info('Loading rawinventor mentions....')
        # Check, read and write the same path so the cache is actually reused.
        rawinventor_cache_file = os.path.join(
            FLAGS.cache_dir, '%s.pkl' % os.path.basename(FLAGS.rawinventor))
        if os.path.exists(rawinventor_cache_file):
            logging.info('Using cache of rawinventor mentions %s',
                         rawinventor_cache_file)
            with open(rawinventor_cache_file, 'rb') as fin:
                record_dict = pickle.load(fin)
        else:
            logging.info('No cache of rawinventor mentions %s',
                         rawinventor_cache_file)
            record_dict = {
                m.mention_id: m
                for m in load_inventor_mentions(FLAGS.rawinventor)
            }
            with open(rawinventor_cache_file, 'wb') as fout:
                pickle.dump(record_dict, fout)

        if need_to_load_train:
            logging.info('Loading training data ids and labels')
            train_datasets = dict()
            with open(FLAGS.training_data, 'r') as fin:
                for line in fin:
                    splt = line.strip().split('\t')

                    # check to see we have the point id in the record dictionary
                    if splt[0] not in record_dict:
                        logging.info("[train] missing record %s", splt[0])
                        continue

                    pt_canopy = record2canopy[splt[0]]
                    if pt_canopy not in train_datasets:
                        train_datasets[pt_canopy] = [[], []]
                    train_datasets[pt_canopy][0].append(splt[0])
                    train_datasets[pt_canopy][1].append(splt[1])
            training_set_tmp = dict()
            all_training_pids = []
            all_training_records = []
            all_training_labels = []
            for canopy, (pids, lbls) in train_datasets.items():
                records = [record_dict[x] for x in pids]
                all_training_pids.extend(pids)
                all_training_records.extend(records)
                all_training_labels.extend(lbls)
                training_set_tmp[canopy] = [pids, lbls, records]
            train_datasets = training_set_tmp

            training_collection = [
                all_training_pids, all_training_labels, all_training_records,
                train_datasets
            ]
            with open(train_cache_file, 'wb') as fout:
                pickle.dump(training_collection, fout)

        if need_to_load_dev:
            logging.info('Loading dev data ids and labels')

            canopies = set()
            point2label = dict()
            with open(FLAGS.dev_data, 'r') as fin:
                for line in fin:
                    splt = line.strip().split('\t')

                    # check to see we have the point id in the record dictionary
                    if splt[0] not in record_dict:
                        logging.info("[dev] missing record %s", splt[0])
                        continue

                    point2label[splt[0]] = splt[1]
                    canopies.add(record2canopy[splt[0]])

            dev_set_tmp = dict()
            for canopy in canopies:
                pids = canopy2record[canopy]
                # Points without a dev label get '-1' (excluded from scoring).
                lbls = [point2label.get(x, '-1') for x in pids]
                records = [record_dict[x] for x in pids]
                if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size:
                    dev_set_tmp[canopy] = [pids, lbls, records]
            dev_datasets = dev_set_tmp

            dev_collection = [dev_datasets]
            with open(dev_cache_file, 'wb') as fout:
                pickle.dump(dev_collection, fout)

    logging.info('Number of train canopies: %s', len(train_datasets))
    logging.info('Number of dev canopies: %s', len(dev_datasets))

    encoding_model = InventorModel.from_flags(FLAGS)

    model = LinearAndRuleModel.from_encoding_model(encoding_model)

    logging.info('building features for the training model...')
    logging.info('len(all_training_records) = %s', len(all_training_records))

    features = encoding_model.encode(all_training_records)
    grinch = Agglom(model, features, num_points=len(all_training_pids))

    logging.info('setting up features for dev datasets...')
    dev_sets = dict()
    for canopy, dataset in tqdm(dev_datasets.items(), 'dev data build'):
        pids, lbls, records = dataset[0], dataset[1], dataset[2]
        dev_sets[canopy] = [
            pids, lbls, records,
            encoding_model.encode(records)
        ]

    trainer = Trainer(outdir=outdir,
                      model=model,
                      encoding_model=encoding_model,
                      grinch=grinch,
                      pids=all_training_pids,
                      labels=all_training_labels,
                      points=all_training_records,
                      dev_data=dev_sets,
                      num_samples=FLAGS.num_samples,
                      num_negatives=FLAGS.num_negatives,
                      batch_size=FLAGS.batch_size,
                      dev_every=FLAGS.dev_every,
                      epochs=FLAGS.epochs,
                      lr=FLAGS.lr,
                      weight_decay=0.0,
                      margin=FLAGS.margin,
                      num_thresholds=FLAGS.num_thresholds,
                      max_dev_size=FLAGS.max_dev_canopy_size,
                      max_dev_canopies=FLAGS.max_num_dev_canopies)
    trainer.train()
# Example #8
# 0
def main(argv):
    """Evaluate a saved model on the dev canopies at its stored threshold.

    Dev data is loaded from a pickle cache under ``FLAGS.cache_dir``, or built
    and cached there. Dev canopies are kept only when their size is within
    ``[FLAGS.min_dev_size, FLAGS.max_dev_size]``. The model is loaded with
    ``torch.load(FLAGS.model_path)``. Each canopy is clustered with
    ``Agglom`` HAC cut at ``model.aux['threshold']``, and micro pairwise F1 is
    logged over points whose gold label is not ``'-1'``. Predictions are
    written to ``<outdir>/dev/dev.tsv``.

    :param argv: command-line arguments (only logged).
    """
    logging.info('Running clustering with argv %s', str(argv))
    wandb.init(project="%s-%s" % (FLAGS.exp_name, FLAGS.dataset_name))
    wandb.config.update(flags.FLAGS)

    outdir = os.path.join(FLAGS.outprefix, wandb.env.get_project(),
                          os.environ.get(wandb.env.SWEEP_ID, 'solo'),
                          wandb.env.get_run())

    logging.info('outdir %s', outdir)

    os.makedirs(outdir, exist_ok=True)

    os.makedirs(FLAGS.cache_dir, exist_ok=True)

    dev_cache_file = os.path.join(
        FLAGS.cache_dir, 'dev-%s.pkl' % os.path.basename(FLAGS.dev_data))

    need_to_load_dev = False
    if os.path.exists(dev_cache_file):
        logging.info('Using cached dev data! %s', dev_cache_file)
        with open(dev_cache_file, 'rb') as fin:
            dev_collection = pickle.load(fin)
        # Re-apply the size filter in case the size flags changed since the
        # cache was written.
        dev_datasets = {
            canopy: [pids, lbls, records]
            for canopy, (pids, lbls, records) in dev_collection[0].items()
            if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size
        }
    else:
        need_to_load_dev = True

    if need_to_load_dev:
        logging.info('need_to_load_dev %s', need_to_load_dev)
        logging.info('Loading canopy dictionary...')
        with open(FLAGS.canopy2record_dict, 'rb') as fin:
            canopy2record = pickle.load(fin)

        with open(FLAGS.record2canopy_dict, 'rb') as fin:
            record2canopy = pickle.load(fin)

        logging.info('Loading rawinventor mentions....')
        # Check, read and write the same path so the cache is actually reused.
        rawinventor_cache_file = os.path.join(
            FLAGS.cache_dir, '%s.pkl' % os.path.basename(FLAGS.rawinventor))
        if os.path.exists(rawinventor_cache_file):
            logging.info('Using cache of rawinventor mentions %s',
                         rawinventor_cache_file)
            with open(rawinventor_cache_file, 'rb') as fin:
                record_dict = pickle.load(fin)
        else:
            logging.info('No cache of rawinventor mentions %s',
                         rawinventor_cache_file)
            record_dict = {
                m.mention_id: m
                for m in load_inventor_mentions(FLAGS.rawinventor)
            }
            with open(rawinventor_cache_file, 'wb') as fout:
                pickle.dump(record_dict, fout)

        logging.info('Loading dev data ids and labels')

        canopies = set()
        point2label = dict()
        with open(FLAGS.dev_data, 'r') as fin:
            for line in fin:
                splt = line.strip().split('\t')

                # check to see we have the point id in the record dictionary
                if splt[0] not in record_dict:
                    logging.info("[dev] missing record %s", splt[0])
                    continue

                point2label[splt[0]] = splt[1]
                canopies.add(record2canopy[splt[0]])

        dev_set_tmp = dict()
        for canopy in canopies:
            pids = canopy2record[canopy]
            # Points without a dev label get '-1' (excluded from scoring).
            lbls = [point2label.get(x, '-1') for x in pids]
            records = [record_dict[x] for x in pids]
            if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size:
                dev_set_tmp[canopy] = [pids, lbls, records]
        dev_datasets = dev_set_tmp

        dev_collection = [dev_datasets]
        with open(dev_cache_file, 'wb') as fout:
            pickle.dump(dev_collection, fout)

    logging.info('Number of dev canopies: %s', len(dev_datasets))

    encoding_model = InventorModel.from_flags(FLAGS)

    # NOTE(review): torch.load unpickles arbitrary objects; model_path must be
    # a trusted file.
    model = torch.load(FLAGS.model_path)

    logging.info('setting up features for dev datasets...')
    dev_sets = dict()
    for canopy, dataset in tqdm(dev_datasets.items(), 'dev data build'):
        pids, lbls, records = dataset[0], dataset[1], dataset[2]
        dev_sets[canopy] = [
            pids, lbls, records,
            encoding_model.encode(records)
        ]

    logging.info('dev eval using %s datasets', len(dev_sets))
    trees = []
    gold_clustering = []
    dataset_names = []
    for idx, (dataset_name, dataset) in enumerate(dev_sets.items()):
        pids, lbls, features = dataset[0], dataset[1], dataset[3]
        logging.info('Running on dev dataset %s of %s | %s with %s points',
                     idx, len(dev_sets), dataset_name, len(pids))
        if len(pids) > 0:
            grinch = Agglom(model, features, num_points=len(pids))
            grinch.build_dendrogram_hac()
            trees.append(grinch)
            gold_clustering.extend(lbls)
            dataset_names.append(dataset_name)
    eval_ids = [i for i, lbl in enumerate(gold_clustering) if lbl != '-1']

    threshold = model.aux['threshold']
    pred_clustering = []
    # Keep each tree's flat clustering so it is computed only once and reused
    # when writing predictions below.
    flat_clusterings = []
    for name, t in zip(dataset_names, trees):
        fc = t.flat_clustering(threshold)
        flat_clusterings.append(fc)
        pred_clustering.extend(['%s-%s' % (name, x) for x in fc])
    metrics = eval_micro_pw_f1([pred_clustering[x] for x in eval_ids],
                               [gold_clustering[x] for x in eval_ids])
    logging.info('[dev] threshold %s | %s', threshold,
                 "|".join(['%s=%s' % (k, v) for k, v in metrics.items()]))

    os.makedirs(os.path.join(outdir, 'dev'), exist_ok=True)
    with open(os.path.join(outdir, 'dev', 'dev.tsv'), 'w') as dev_out_f:
        for name, fc in zip(dataset_names, flat_clusterings):
            dataset = dev_sets[name]
            lbls, records = dataset[1], dataset[2]
            for j, record in enumerate(records):
                dev_out_f.write(
                    "%s\n" %
                    record.pretty_tsv('%s-%s' % (name, fc[j]), lbls[j]))