def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions, canopy2tree, trees):
    features = encoding_model.encode(all_records)
    if len(all_pids) > 1:
        grinch = Agglom(model, features, num_points=len(all_pids))
        grinch.build_dendrogram_hac()
        fc = grinch.flat_clustering(model.aux['threshold'])
        tree_id = len(trees)
        trees.append(grinch)
        for i in range(len(all_pids)):
            if all_canopies[i] not in canopy2predictions:
                canopy2predictions[all_canopies[i]] = [[], []]
                canopy2tree[all_canopies[i]] = tree_id
            canopy2predictions[all_canopies[i]][0].append(all_pids[i])
            canopy2predictions[all_canopies[i]][1].append(
                '%s-%s' % (all_canopies[i], fc[i]))
        return canopy2predictions
    else:
        fc = [0]
        for i in range(len(all_pids)):
            if all_canopies[i] not in canopy2predictions:
                canopy2predictions[all_canopies[i]] = [[], []]
                canopy2tree[all_canopies[i]] = None
            canopy2predictions[all_canopies[i]][0].append(all_pids[i])
            canopy2predictions[all_canopies[i]][1].append(
                '%s-%s' % (all_canopies[i], fc[i]))
        return canopy2predictions

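# A minimal, dependency-free sketch (hypothetical record ids, canopy name, and
# cluster assignments; no grinch imports) of the output structure run_on_batch
# builds: canopy2predictions maps a canopy to two parallel lists,
# [record ids, predicted entity ids], with cluster ids namespaced by canopy so
# they stay unique across canopies.
example_predictions = {}
for pid, canopy, cluster in [('p-1', 'fl:j_ln:smith', 0),
                             ('p-2', 'fl:j_ln:smith', 0),
                             ('p-3', 'fl:j_ln:smith', 1)]:
    if canopy not in example_predictions:
        example_predictions[canopy] = [[], []]
    example_predictions[canopy][0].append(pid)
    example_predictions[canopy][1].append('%s-%s' % (canopy, cluster))
assert example_predictions['fl:j_ln:smith'][1] == [
    'fl:j_ln:smith-0', 'fl:j_ln:smith-0', 'fl:j_ln:smith-1']
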
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions, canopy2tree, trees, pids_list):
    """
    :param all_pids: ids of the records we are running on
    :param all_lbls: labels (or blank labels if they are not available)
    :param all_records: the records (i.e. features - python strings of the inventor name, assignee, etc.)
    :param all_canopies: the canopy of each record (currently always the same canopy for the whole batch)
    :param model: linear scoring / similarity model which takes two feature vectors and gives a similarity
    :param encoding_model: maps all_records to the input expected by model
    :param canopy2predictions: where results are stored
    :param canopy2tree: where results are stored
    :param trees: where results are stored
    :param pids_list: where results are stored
    :return: canopy2predictions, updated in place
    """
    # extract features
    features = encoding_model.encode(all_records)
    if len(all_pids) <= 1:
        raise Exception('Must have non-singleton canopies')
    # run clustering
    grinch = Agglom(model, features, num_points=len(all_pids))
    grinch.build_dendrogram_hac()
    fc = grinch.flat_clustering(model.aux['threshold'])
    # store the state of the clustering for the incremental setting:
    # the tree built for this canopy and the point ids it covers
    tree_id = len(trees)
    trees.append(grinch)
    pids_list.append(all_pids)
    for i in range(len(all_pids)):
        # record the mapping from canopy to tree id
        canopy2tree[all_canopies[i]] = tree_id
        if all_canopies[i] not in canopy2predictions:
            canopy2predictions[all_canopies[i]] = [[], []]
        # save predictions (used in the non-incremental setting)
        canopy2predictions[all_canopies[i]][0].append(all_pids[i])
        canopy2predictions[all_canopies[i]][1].append(
            '%s-%s' % (all_canopies[i], fc[i]))
    return canopy2predictions

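# A small sketch of how the incremental state saved above can be used later
# (lookup_canopy_state is a hypothetical helper, not part of the codebase):
# canopy2tree, trees, and pids_list share the same tree_id index, so the
# dendrogram for a canopy and the record ids it covers are recovered together.
def lookup_canopy_state(canopy, canopy2tree, trees, pids_list):
    tree_id = canopy2tree[canopy]
    return trees[tree_id], pids_list[tree_id]
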
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions, canopy2tree, trees,
                 pids_list, canopy_list):
    features = encoding_model.encode(all_records)
    grinch = Agglom(model, features, num_points=len(all_pids), min_allowable_sim=0)
    grinch.build_dendrogram_hac()
    fc = grinch.flat_clustering(model.aux['threshold'])
    logging.info('run_on_batch - threshold %s, linkages: min %s, max %s, avg %s, std %s',
                 model.aux['threshold'],
                 np.min(grinch.all_thresholds()),
                 np.max(grinch.all_thresholds()),
                 np.mean(grinch.all_thresholds()),
                 np.std(grinch.all_thresholds()))
    tree_id = len(trees)
    trees.append(grinch)
    pids_list.append(all_pids)
    canopy_list.append(all_canopies)
    for i in range(len(all_pids)):
        if all_canopies[i] not in canopy2predictions:
            canopy2predictions[all_canopies[i]] = [[], []]
            canopy2tree[all_canopies[i]] = tree_id
        canopy2predictions[all_canopies[i]][0].append(all_pids[i])
        canopy2predictions[all_canopies[i]][1].append('%s-%s' % (all_canopies[i], fc[i]))
    return canopy2predictions

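# For intuition only: a hedged analogy using scipy that mirrors the two steps
# above (build_dendrogram_hac, then flat_clustering at a threshold). This
# illustrates the HAC-then-cut technique, not the grinch implementation itself.
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

toy_points = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
Z = linkage(toy_points, method='average')        # build the dendrogram
flat = fcluster(Z, t=1.0, criterion='distance')  # cut it at threshold 1.0
# the two nearby points end up in one flat cluster; the far point in another
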
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions):
    features = encoding_model.encode(all_records)
    grinch = Agglom(model, features, num_points=len(all_pids))
    grinch.build_dendrogram_hac()
    fc = grinch.flat_clustering(model.aux['threshold'])
    for i in range(len(all_pids)):
        if all_canopies[i] not in canopy2predictions:
            canopy2predictions[all_canopies[i]] = [[], []]
        canopy2predictions[all_canopies[i]][0].append(all_pids[i])
        canopy2predictions[all_canopies[i]][1].append('%s-%s' % (all_canopies[i], fc[i]))
    return canopy2predictions

def dev_eval(self, datasets):
    logging.info('dev eval using %s datasets', len(datasets))
    trees = []
    gold_clustering = []
    dataset_names = []
    for idx, (dataset_name, dataset) in enumerate(datasets.items()):
        pids, lbls, records, features = dataset[0], dataset[1], dataset[2], dataset[3]
        logging.info('Running on dev dataset %s of %s | %s with %s points',
                     idx, len(datasets), dataset_name, len(pids))
        if len(pids) > 0:
            grinch = Agglom(self.grinch.model, features, num_points=len(pids))
            grinch.build_dendrogram_hac()
            trees.append(grinch)
            gold_clustering.extend(lbls)
            dataset_names.append(dataset_name)
    eval_ids = [i for i in range(len(gold_clustering)) if gold_clustering[i] != '-1']
    thresholds = np.sort(np.squeeze(self.get_thresholds(trees, self.num_thresholds)))
    scores_per_threshold = []
    os.makedirs(os.path.join(self.outdir, 'dev'), exist_ok=True)
    dev_out_f = open(os.path.join(self.outdir, 'dev', 'dev_%s.tsv' % self.global_step), 'w')
    for thres in thresholds:
        pred_clustering = []
        for idx, t in enumerate(trees):
            fc = t.flat_clustering(thres)
            pred_clustering.extend(['%s-%s' % (dataset_names[idx], x) for x in fc])
        metrics = eval_micro_pw_f1([pred_clustering[x] for x in eval_ids],
                                   [gold_clustering[x] for x in eval_ids])
        scores_per_threshold.append(metrics)
        logging.info('[dev] threshold %s | %s', thres,
                     "|".join(['%s=%s' % (k, v) for k, v in metrics.items()]))
    arg_best_f1 = max(range(len(scores_per_threshold)),
                      key=lambda x: scores_per_threshold[x]['micro_pw_f1'])
    for idx, t in enumerate(trees):
        dataset = datasets[dataset_names[idx]]
        pids, lbls, records, features = dataset[0], dataset[1], dataset[2], dataset[3]
        fc = t.flat_clustering(thresholds[arg_best_f1])
        for j in range(len(records)):
            dev_out_f.write("%s\n" % records[j].pretty_tsv('%s-%s' % (dataset_names[idx], fc[j]), lbls[j]))
    metrics = scores_per_threshold[arg_best_f1]
    logging.info('[dev] best threshold %s | %s', thresholds[arg_best_f1],
                 "|".join(['%s=%s' % (k, v) for k, v in metrics.items()]))
    dev_out_f.close()
    if metrics['micro_pw_f1'] > self.best_f1:
        logging.info('new best f1 %s > %s', metrics['micro_pw_f1'], self.best_f1)
        self.model.aux['threshold'] = thresholds[arg_best_f1]
        self.best_f1 = metrics['micro_pw_f1']
        self.save('best')
    return thresholds[arg_best_f1], metrics

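# A hedged, self-contained sketch of the metric eval_micro_pw_f1 is assumed to
# compute above: micro pairwise F1, i.e. precision/recall over all pairs of
# points that share a predicted / gold cluster. This is an illustrative
# reimplementation, not the project's own code.
from itertools import combinations

def micro_pairwise_f1(pred, gold):
    pred_pairs = {p for p in combinations(range(len(pred)), 2) if pred[p[0]] == pred[p[1]]}
    gold_pairs = {p for p in combinations(range(len(gold)), 2) if gold[p[0]] == gold[p[1]]}
    tp = len(pred_pairs & gold_pairs)
    precision = tp / len(pred_pairs) if pred_pairs else 0.0
    recall = tp / len(gold_pairs) if gold_pairs else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# micro_pairwise_f1(['a', 'a', 'b'], ['x', 'x', 'x']) -> 0.5
# (precision 1.0: the one predicted pair is correct; recall 1/3: one of three gold pairs found)
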
def run_on_batch(all_pids, all_lbls, all_records, all_canopies, model,
                 encoding_model, canopy2predictions):
    features = encoding_model.encode(all_records)
    grinch = Agglom(model, features, num_points=len(all_pids), min_allowable_sim=0)
    grinch.build_dendrogram_hac()
    fc = grinch.flat_clustering(model.aux['threshold'])
    for i in range(len(all_pids)):
        if all_canopies[i] not in canopy2predictions:
            canopy2predictions[all_canopies[i]] = [[], []]
        canopy2predictions[all_canopies[i]][0].append(all_pids[i])
        canopy2predictions[all_canopies[i]][1].append('%s-%s' % (all_canopies[i], fc[i]))
    return canopy2predictions

def main(argv):
    logging.info('Running clustering with argv %s', str(argv))
    wandb.init(project="%s-%s" % (FLAGS.exp_name, FLAGS.dataset_name))
    wandb.config.update(flags.FLAGS)
    outdir = os.path.join(FLAGS.outprefix, wandb.env.get_project(),
                          os.environ.get(wandb.env.SWEEP_ID, 'solo'),
                          wandb.env.get_run())
    logging.info('outdir %s', outdir)
    os.makedirs(outdir, exist_ok=True)
    os.makedirs(FLAGS.cache_dir, exist_ok=True)

    need_to_load_train = False
    train_cache_file = os.path.join(
        FLAGS.cache_dir, 'train-%s.pkl' % os.path.basename(FLAGS.training_data))
    dev_cache_file = os.path.join(
        FLAGS.cache_dir, 'dev-%s.pkl' % os.path.basename(FLAGS.dev_data))
    if os.path.exists(train_cache_file):
        logging.info('Using cached training data! %s', train_cache_file)
        with open(train_cache_file, 'rb') as fin:
            training_collection = pickle.load(fin)
            all_training_pids, all_training_labels, all_training_records, train_datasets = training_collection
    else:
        need_to_load_train = True

    need_to_load_dev = False
    if os.path.exists(dev_cache_file):
        logging.info('Using cached dev data! %s', dev_cache_file)
        with open(dev_cache_file, 'rb') as fin:
            dev_collection = pickle.load(fin)
            dev_datasets = dev_collection[0]
            dev_set_tmp = dict()
            for canopy in dev_datasets.keys():
                pids, lbls, records = dev_datasets[canopy]
                if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size:
                    dev_set_tmp[canopy] = [pids, lbls, records]
            dev_datasets = dev_set_tmp
    else:
        need_to_load_dev = True

    if need_to_load_train or need_to_load_dev:
        logging.info('need_to_load_train %s', need_to_load_train)
        logging.info('need_to_load_dev %s', need_to_load_dev)
        logging.info('Loading canopy dictionary...')
        with open(FLAGS.canopy2record_dict, 'rb') as fin:
            canopy2record = pickle.load(fin)
        with open(FLAGS.record2canopy_dict, 'rb') as fin:
            record2canopy = pickle.load(fin)

        logging.info('Loading rawinventor mentions....')
        rawinventor_cache_file = os.path.join(
            FLAGS.cache_dir, '%s.pkl' % os.path.basename(FLAGS.rawinventor))
        if os.path.exists(rawinventor_cache_file):
            logging.info('Using cache of rawinventor mentions %s', rawinventor_cache_file)
            with open(rawinventor_cache_file, 'rb') as fin:
                record_dict = pickle.load(fin)
        else:
            logging.info('No cache of rawinventor mentions %s', rawinventor_cache_file)
            record_dict = dict()
            for m in load_inventor_mentions(FLAGS.rawinventor):
                record_dict[m.mention_id] = m
            with open(rawinventor_cache_file, 'wb') as fout:
                pickle.dump(record_dict, fout)

    if need_to_load_train:
        logging.info('Loading training data ids and labels')
        training_point_ids = []
        training_labels = []
        train_datasets = dict()
        with open(FLAGS.training_data, 'r') as fin:
            for line in fin:
                splt = line.strip().split('\t')
                # skip point ids that are missing from the record dictionary
                if splt[0] not in record_dict:
                    logging.info("[train] missing record %s", splt[0])
                    continue
                training_point_ids.append(splt[0])
                training_labels.append(splt[1])
                pt_canopy = record2canopy[splt[0]]
                if pt_canopy not in train_datasets:
                    train_datasets[pt_canopy] = [[], []]
                train_datasets[pt_canopy][0].append(splt[0])
                train_datasets[pt_canopy][1].append(splt[1])
        training_set_tmp = dict()
        all_training_pids = []
        all_training_records = []
        all_training_labels = []
        for canopy, dataset in train_datasets.items():
            pids, lbls = dataset[0], dataset[1]
            records = [record_dict[x] for x in pids]
            all_training_pids.extend(pids)
            all_training_records.extend(records)
            all_training_labels.extend(lbls)
            training_set_tmp[canopy] = [pids, lbls, records]
        train_datasets = training_set_tmp
        training_collection = [all_training_pids, all_training_labels,
                               all_training_records, train_datasets]
        with open(train_cache_file, 'wb') as fout:
            pickle.dump(training_collection, fout)

    if need_to_load_dev:
        logging.info('Loading dev data ids and labels')
        canopies = set()
        point2label = dict()
        with open(FLAGS.dev_data, 'r') as fin:
            for line in fin:
                splt = line.strip().split('\t')
                # skip point ids that are missing from the record dictionary
                if splt[0] not in record_dict:
                    logging.info("[dev] missing record %s", splt[0])
                    continue
                point2label[splt[0]] = splt[1]
                pt_canopy = record2canopy[splt[0]]
                canopies.add(pt_canopy)
        dev_set_tmp = dict()
        for canopy in canopies:
            pids = canopy2record[canopy]
            lbls = [point2label[x] if x in point2label else '-1' for x in pids]
            records = [record_dict[x] for x in pids]
            if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size:
                dev_set_tmp[canopy] = [pids, lbls, records]
        dev_datasets = dev_set_tmp
        dev_collection = [dev_datasets]
        with open(dev_cache_file, 'wb') as fout:
            pickle.dump(dev_collection, fout)

    logging.info('Number of train canopies: %s', len(train_datasets))
    logging.info('Number of dev canopies: %s', len(dev_datasets))

    encoding_model = InventorModel.from_flags(FLAGS)
    model = LinearAndRuleModel.from_encoding_model(encoding_model)

    logging.info('building features for the training model...')
    logging.info('len(all_training_records) = %s', len(all_training_records))
    features = encoding_model.encode(all_training_records)
    grinch = Agglom(model, features, num_points=len(all_training_pids))

    logging.info('setting up features for dev datasets...')
    dev_sets = dict()
    for canopy, dataset in tqdm(dev_datasets.items(), 'dev data build'):
        pids, lbls, records = dataset[0], dataset[1], dataset[2]
        dev_sets[canopy] = [pids, lbls, records, encoding_model.encode(records)]

    trainer = Trainer(outdir=outdir,
                      model=model,
                      encoding_model=encoding_model,
                      grinch=grinch,
                      pids=all_training_pids,
                      labels=all_training_labels,
                      points=all_training_records,
                      dev_data=dev_sets,
                      num_samples=FLAGS.num_samples,
                      num_negatives=FLAGS.num_negatives,
                      batch_size=FLAGS.batch_size,
                      dev_every=FLAGS.dev_every,
                      epochs=FLAGS.epochs,
                      lr=FLAGS.lr,
                      weight_decay=0.0,
                      margin=FLAGS.margin,
                      num_thresholds=FLAGS.num_thresholds,
                      max_dev_size=FLAGS.max_dev_canopy_size,
                      max_dev_canopies=FLAGS.max_num_dev_canopies)
    trainer.train()

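# The load-if-cached-else-build-and-dump pattern that main repeats for the
# train, dev, and rawinventor caches, written once as a standalone helper
# (load_or_build is a hypothetical name; the code above inlines this logic):
import os
import pickle

def load_or_build(cache_file, build_fn):
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fin:
            return pickle.load(fin)
    obj = build_fn()
    with open(cache_file, 'wb') as fout:
        pickle.dump(obj, fout)
    return obj
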
def main(argv):
    logging.info('Running clustering with argv %s', str(argv))
    wandb.init(project="%s-%s" % (FLAGS.exp_name, FLAGS.dataset_name))
    wandb.config.update(flags.FLAGS)
    outdir = os.path.join(FLAGS.outprefix, wandb.env.get_project(),
                          os.environ.get(wandb.env.SWEEP_ID, 'solo'),
                          wandb.env.get_run())
    logging.info('outdir %s', outdir)
    os.makedirs(outdir, exist_ok=True)
    os.makedirs(FLAGS.cache_dir, exist_ok=True)

    dev_cache_file = os.path.join(
        FLAGS.cache_dir, 'dev-%s.pkl' % os.path.basename(FLAGS.dev_data))
    need_to_load_dev = False
    if os.path.exists(dev_cache_file):
        logging.info('Using cached dev data! %s', dev_cache_file)
        with open(dev_cache_file, 'rb') as fin:
            dev_collection = pickle.load(fin)
            dev_datasets = dev_collection[0]
            dev_set_tmp = dict()
            for canopy in dev_datasets.keys():
                pids, lbls, records = dev_datasets[canopy]
                if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size:
                    dev_set_tmp[canopy] = [pids, lbls, records]
            dev_datasets = dev_set_tmp
    else:
        need_to_load_dev = True

    if need_to_load_dev:
        logging.info('Loading canopy dictionary...')
        with open(FLAGS.canopy2record_dict, 'rb') as fin:
            canopy2record = pickle.load(fin)
        with open(FLAGS.record2canopy_dict, 'rb') as fin:
            record2canopy = pickle.load(fin)

        logging.info('Loading rawinventor mentions....')
        rawinventor_cache_file = os.path.join(
            FLAGS.cache_dir, '%s.pkl' % os.path.basename(FLAGS.rawinventor))
        if os.path.exists(rawinventor_cache_file):
            logging.info('Using cache of rawinventor mentions %s', rawinventor_cache_file)
            with open(rawinventor_cache_file, 'rb') as fin:
                record_dict = pickle.load(fin)
        else:
            logging.info('No cache of rawinventor mentions %s', rawinventor_cache_file)
            record_dict = dict()
            for m in load_inventor_mentions(FLAGS.rawinventor):
                record_dict[m.mention_id] = m
            with open(rawinventor_cache_file, 'wb') as fout:
                pickle.dump(record_dict, fout)

        logging.info('Loading dev data ids and labels')
        canopies = set()
        point2label = dict()
        with open(FLAGS.dev_data, 'r') as fin:
            for line in fin:
                splt = line.strip().split('\t')
                # skip point ids that are missing from the record dictionary
                if splt[0] not in record_dict:
                    logging.info("[dev] missing record %s", splt[0])
                    continue
                point2label[splt[0]] = splt[1]
                pt_canopy = record2canopy[splt[0]]
                canopies.add(pt_canopy)
        dev_set_tmp = dict()
        for canopy in canopies:
            pids = canopy2record[canopy]
            lbls = [point2label[x] if x in point2label else '-1' for x in pids]
            records = [record_dict[x] for x in pids]
            if FLAGS.min_dev_size <= len(records) <= FLAGS.max_dev_size:
                dev_set_tmp[canopy] = [pids, lbls, records]
        dev_datasets = dev_set_tmp
        dev_collection = [dev_datasets]
        with open(dev_cache_file, 'wb') as fout:
            pickle.dump(dev_collection, fout)

    logging.info('Number of dev canopies: %s', len(dev_datasets))

    encoding_model = InventorModel.from_flags(FLAGS)
    model = torch.load(FLAGS.model_path)

    logging.info('setting up features for dev datasets...')
    dev_sets = dict()
    for canopy, dataset in tqdm(dev_datasets.items(), 'dev data build'):
        pids, lbls, records = dataset[0], dataset[1], dataset[2]
        dev_sets[canopy] = [pids, lbls, records, encoding_model.encode(records)]

    logging.info('dev eval using %s datasets', len(dev_sets))
    trees = []
    gold_clustering = []
    dataset_names = []
    for idx, (dataset_name, dataset) in enumerate(dev_sets.items()):
        pids, lbls, records, features = dataset[0], dataset[1], dataset[2], dataset[3]
        logging.info('Running on dev dataset %s of %s | %s with %s points',
                     idx, len(dev_sets), dataset_name, len(pids))
        if len(pids) > 0:
            grinch = Agglom(model, features, num_points=len(pids))
            grinch.build_dendrogram_hac()
            trees.append(grinch)
            gold_clustering.extend(lbls)
            dataset_names.append(dataset_name)
    # points labeled '-1' have no gold label and are excluded from evaluation
    eval_ids = [i for i in range(len(gold_clustering)) if gold_clustering[i] != '-1']
    os.makedirs(os.path.join(outdir, 'dev'), exist_ok=True)
    dev_out_f = open(os.path.join(outdir, 'dev', 'dev.tsv'), 'w')
    pred_clustering = []
    for idx, t in enumerate(trees):
        fc = t.flat_clustering(model.aux['threshold'])
        pred_clustering.extend(['%s-%s' % (dataset_names[idx], x) for x in fc])
    metrics = eval_micro_pw_f1([pred_clustering[x] for x in eval_ids],
                               [gold_clustering[x] for x in eval_ids])
    logging.info('[dev] threshold %s | %s', model.aux['threshold'],
                 "|".join(['%s=%s' % (k, v) for k, v in metrics.items()]))
    for idx, t in enumerate(trees):
        dataset = dev_sets[dataset_names[idx]]
        pids, lbls, records, features = dataset[0], dataset[1], dataset[2], dataset[3]
        fc = t.flat_clustering(model.aux['threshold'])
        for j in range(len(records)):
            dev_out_f.write("%s\n" % records[j].pretty_tsv('%s-%s' % (dataset_names[idx], fc[j]), lbls[j]))
    dev_out_f.close()

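# Hypothetical launcher, assuming the absl-style FLAGS and logging used above;
# the flag definitions themselves are expected to live elsewhere in the module.
from absl import app

if __name__ == '__main__':
    app.run(main)
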