def test_clutrr_v3():
    """Check that a 2-hop attentive reformulator recovers r(X, Z) <- p(X, Y), q(Y, Z).

    The knowledge base holds 32 chains ``a_i -p-> b_i -q-> c_i`` and the
    training targets are ``r(a_i, c_i)``. After training, the first hop's
    reformulation of ``r`` must be close (Gaussian kernel similarity > 0.95)
    to the embedding of ``p``, and the second hop's to ``q``.
    """
    embedding_size = 20
    batch_size = 8

    torch.manual_seed(0)

    triples, hops = [], []
    for i in range(32):
        triples += [(f'a{i}', 'p', f'b{i}'), (f'b{i}', 'q', f'c{i}')]
        hops += [(f'a{i}', 'r', f'c{i}')]

    entity_lst = sorted({s for (s, _, _) in triples + hops} | {o for (_, _, o) in triples + hops})
    predicate_lst = sorted({p for (_, p, _) in triples + hops})

    nb_entities, nb_predicates = len(entity_lst), len(predicate_lst)

    entity_to_index = {e: i for i, e in enumerate(entity_lst)}
    predicate_to_index = {p: i for i, p in enumerate(predicate_lst)}

    kernel = GaussianKernel(slope=None)

    entity_embeddings = nn.Embedding(nb_entities, embedding_size, sparse=True)
    predicate_embeddings = nn.Embedding(nb_predicates, embedding_size, sparse=True)

    _hops = AttentiveReformulator(2, predicate_embeddings)
    model = NeuralKB(kernel=kernel, scoring_type='concat')

    hoppy = Hoppy(model, hops_lst=[(_hops, False)], depth=1)

    # Train everything except the (fixed) entity and predicate embeddings.
    # Identity, not value equality, is what decides "is this the embedding tensor".
    params = [
        p for p in hoppy.parameters()
        if p is not entity_embeddings.weight and p is not predicate_embeddings.weight
    ]

    for tensor in params:
        print(f'\t{tensor.size()}\t{tensor.device}')

    loss_function = nn.BCELoss()
    optimizer = optim.Adagrad(params, lr=0.1)

    # 64 passes over the target facts.
    hops_data = hops * 64

    batches = make_batches(len(hops_data), batch_size)

    # Seeded so negative sampling (and therefore the test outcome) is reproducible,
    # matching the torch.manual_seed(0) above.
    rs = np.random.RandomState(0)

    c, d = 0.0, 0.0
    p_emb = predicate_embeddings(torch.from_numpy(np.array([predicate_to_index['p']])))
    q_emb = predicate_embeddings(torch.from_numpy(np.array([predicate_to_index['q']])))

    for batch_start, batch_end in batches:
        hops_batch = hops_data[batch_start:batch_end]

        s_lst = [s for (s, _, _) in hops_batch]
        p_lst = [p for (_, p, _) in hops_batch]
        o_lst = [o for (_, _, o) in hops_batch]

        # Three random (subject, object) negatives per positive, same predicates.
        nb_positives = len(s_lst)
        nb_negatives = nb_positives * 3

        s_n_lst = rs.permutation(nb_entities)[:nb_negatives].tolist()
        nb_negatives = len(s_n_lst)
        o_n_lst = rs.permutation(nb_entities)[:nb_negatives].tolist()
        p_n_lst = list(islice(cycle(p_lst), nb_negatives))

        xs_np = np.array([entity_to_index[s] for s in s_lst] + s_n_lst)
        xp_np = np.array([predicate_to_index[p] for p in p_lst + p_n_lst])
        xo_np = np.array([entity_to_index[o] for o in o_lst] + o_n_lst)

        xs_emb = entity_embeddings(torch.from_numpy(xs_np))
        xp_emb = predicate_embeddings(torch.from_numpy(xp_np))
        xo_emb = entity_embeddings(torch.from_numpy(xo_np))

        rel_emb = encode_relation(facts=triples,
                                  relation_embeddings=predicate_embeddings,
                                  relation_to_idx=predicate_to_index)

        arg1_emb, arg2_emb = encode_arguments(facts=triples,
                                              entity_embeddings=entity_embeddings,
                                              entity_to_idx=entity_to_index)

        facts = [rel_emb, arg1_emb, arg2_emb]

        scores = hoppy.score(xp_emb, xs_emb, xo_emb, facts=facts,
                             entity_embeddings=entity_embeddings.weight)

        # Positives come first in the batch, followed by the negatives.
        labels_np = np.zeros(xs_np.shape[0])
        labels_np[:nb_positives] = 1
        labels = torch.from_numpy(labels_np).float()

        loss = loss_function(scores, labels)

        # Similarity between each hop's reformulation of `r` and the gold predicate.
        hop_1_emb = hoppy.hops_lst[0][0].hops_lst[0](xp_emb)
        hop_2_emb = hoppy.hops_lst[0][0].hops_lst[1](xp_emb)

        c = kernel.pairwise(p_emb, hop_1_emb).mean().cpu().detach().numpy()
        d = kernel.pairwise(q_emb, hop_2_emb).mean().cpu().detach().numpy()

        print(c, d)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    assert c > 0.95 and d > 0.95, f'hop similarities too low: c={c}, d={d}'
def main(argv):
    """Train a Hoppy model on CLUTRR and log test accuracy.

    Parses ``argv`` with ``argparse``. It then builds entity and relation
    embeddings, plus one reformulator per entry of ``--hops`` (a trailing
    ``R`` marks a reversed hop). The model is trained with one optimizer step
    per training instance. Every ``--evaluate-every`` steps, accuracy is
    logged for each ``--test`` file, and it is logged once more after
    training.

    Args:
        argv: command-line arguments (without the program name), forwarded
            to ``ArgumentParser.parse_args``.
    """
    argparser = argparse.ArgumentParser('CLUTRR', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    train_path = test_path = "data/clutrr-emnlp/data_test/64.csv"

    argparser.add_argument('--train', action='store', type=str, default=train_path)
    argparser.add_argument('--test', nargs='+', type=str, default=[test_path])

    # model params
    argparser.add_argument('--embedding-size', '-k', action='store', type=int, default=20)
    argparser.add_argument('--k-max', '-m', action='store', type=int, default=10)
    argparser.add_argument('--max-depth', '-d', action='store', type=int, default=2)
    argparser.add_argument('--test-max-depth', action='store', type=int, default=None)

    argparser.add_argument('--hops', nargs='+', type=str, default=['2', '2', '1R'])

    # training params
    argparser.add_argument('--epochs', '-e', action='store', type=int, default=100)
    argparser.add_argument('--learning-rate', '-l', action='store', type=float, default=0.1)
    argparser.add_argument('--batch-size', '-b', action='store', type=int, default=8)
    argparser.add_argument('--optimizer', '-o', action='store', type=str, default='adagrad',
                           choices=['adagrad', 'adam', 'sgd'])

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--evaluate-every', '-V', action='store', type=int, default=32)

    argparser.add_argument('--N2', action='store', type=float, default=None)
    argparser.add_argument('--N3', action='store', type=float, default=None)
    argparser.add_argument('--entropy', '-E', action='store', type=float, default=None)

    argparser.add_argument('--scoring-type', '-s', action='store', type=str, default='min',
                           choices=['concat', 'min'])

    argparser.add_argument('--tnorm', '-t', action='store', type=str, default='min',
                           choices=['min', 'prod'])
    argparser.add_argument('--reformulator', '-r', action='store', type=str, default='linear',
                           choices=['static', 'linear', 'attentive', 'memory', 'ntp'])
    argparser.add_argument('--nb-rules', '-R', action='store', type=int, default=4)

    argparser.add_argument('--GNTP-R', action='store', type=int, default=None)

    argparser.add_argument('--slope', '-S', action='store', type=float, default=None)
    argparser.add_argument('--init-size', '-i', action='store', type=float, default=1.0)

    argparser.add_argument('--init', action='store', type=str, default='uniform')
    argparser.add_argument('--ref-init', action='store', type=str, default='random')

    argparser.add_argument('--debug', '-D', action='store_true', default=False)

    argparser.add_argument('--load', action='store', type=str, default=None)
    argparser.add_argument('--save', action='store', type=str, default=None)

    args = argparser.parse_args(argv)

    train_path = args.train
    test_paths = args.test

    embedding_size = args.embedding_size

    k_max = args.k_max
    max_depth = args.max_depth
    test_max_depth = args.test_max_depth

    hops_str = args.hops

    nb_epochs = args.epochs
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    optimizer_name = args.optimizer
    seed = args.seed

    evaluate_every = args.evaluate_every

    N2_weight = args.N2
    N3_weight = args.N3
    entropy_weight = args.entropy

    scoring_type = args.scoring_type
    tnorm_name = args.tnorm
    reformulator_name = args.reformulator
    nb_rules = args.nb_rules

    gntp_R = args.GNTP_R

    slope = args.slope
    init_size = args.init_size

    init_type = args.init
    ref_init_type = args.ref_init

    is_debug = args.debug

    load_path = args.load
    save_path = args.save

    np.random.seed(seed)
    random_state = np.random.RandomState(seed)
    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Device: {device}')

    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    data = Data(train_path=train_path, test_paths=test_paths)

    relation_to_predicate = data.relation_to_predicate
    predicate_to_relations = data.predicate_to_relations
    entity_lst, relation_lst = data.entity_lst, data.relation_lst

    nb_examples = len(data.train)
    nb_entities = len(entity_lst)
    nb_relations = len(relation_lst)

    entity_to_idx = {e: i for i, e in enumerate(entity_lst)}
    relation_to_idx = {r: i for i, r in enumerate(relation_lst)}

    kernel = GaussianKernel(slope=slope)

    entity_embeddings = nn.Embedding(nb_entities, embedding_size, sparse=False).to(device)
    nn.init.uniform_(entity_embeddings.weight, -1.0, 1.0)
    # Entity embeddings are random and never trained. Freeze the *weight*.
    # Setting `requires_grad` on the Module itself is just a Python attribute
    # and would leave the weight accumulating gradients that are never cleared.
    entity_embeddings.weight.requires_grad_(False)

    relation_embeddings = nn.Embedding(nb_relations, embedding_size, sparse=False).to(device)

    if init_type in {'uniform'}:
        nn.init.uniform_(relation_embeddings.weight, -1.0, 1.0)

    relation_embeddings.weight.data *= init_size

    model = NeuralKB(kernel=kernel, scoring_type=scoring_type).to(device)
    # Shared by every hop when --reformulator=memory (created on first use).
    memory = None

    def make_hop(s: str) -> Tuple[BaseReformulator, bool]:
        """Build the reformulator for one --hops entry, e.g. '2' or '1R' (reversed)."""
        nonlocal memory
        if s.isdigit():
            nb_hops, is_reversed = int(s), False
        else:
            nb_hops, is_reversed = int(s[:-1]), True
        res = None
        if reformulator_name in {'static'}:
            res = StaticReformulator(nb_hops, embedding_size, init_name=ref_init_type)
        elif reformulator_name in {'linear'}:
            res = LinearReformulator(nb_hops, embedding_size, init_name=ref_init_type)
        elif reformulator_name in {'attentive'}:
            res = AttentiveReformulator(nb_hops, relation_embeddings, init_name=ref_init_type)
        elif reformulator_name in {'memory'}:
            if memory is None:
                memory = MemoryReformulator.Memory(nb_hops, nb_rules, embedding_size, init_name=ref_init_type)
            res = MemoryReformulator(memory)
        elif reformulator_name in {'ntp'}:
            res = NTPReformulator(nb_hops=nb_hops, embedding_size=embedding_size,
                                  kernel=kernel, init_name=ref_init_type)
        assert res is not None
        return res, is_reversed

    hops_lst = [make_hop(s) for s in hops_str]

    hoppy = Hoppy(model=model, k=k_max, depth=max_depth, tnorm_name=tnorm_name,
                  hops_lst=hops_lst, R=gntp_R).to(device)

    def scoring_function(story: List[Fact], targets: List[Fact]) -> Tensor:
        """Score each target fact against the facts in `story`.

        Uses `test_max_depth` as the proof depth when it is set. The training
        depth is restored afterwards, even if scoring raises.
        """
        story_rel = encode_relation(story, relation_embeddings.weight, relation_to_idx, device)
        story_arg1, story_arg2 = encode_arguments(story, entity_embeddings.weight, entity_to_idx, device)

        targets_rel = encode_relation(targets, relation_embeddings.weight, relation_to_idx, device)
        targets_arg1, targets_arg2 = encode_arguments(targets, entity_embeddings.weight, entity_to_idx, device)

        embeddings = encode_entities(story, entity_embeddings.weight, entity_to_idx, device)

        facts = [story_rel, story_arg1, story_arg2]

        max_depth_ = hoppy.depth
        if test_max_depth is not None:
            hoppy.depth = test_max_depth
        try:
            scores = hoppy.score(targets_rel, targets_arg1, targets_arg2, facts, embeddings)
        finally:
            if test_max_depth is not None:
                hoppy.depth = max_depth_

        return scores

    def evaluate(instances: List[Instance], path: str, sample_size: Optional[int] = None) -> float:
        """Return (and log) the accuracy on `instances`. The result is 0.0 when `instances` is empty."""
        res = 0.0
        if len(instances) > 0:
            res = accuracy(scoring_function=scoring_function,
                           instances=instances,
                           sample_size=sample_size,
                           relation_to_predicate=relation_to_predicate,
                           predicate_to_relations=predicate_to_relations)
            logger.info(f'Test Accuracy on {path}: {res:.6f}')
        return res

    loss_function = nn.BCELoss()

    N2_reg = N2() if N2_weight is not None else None
    N3_reg = N3() if N3_weight is not None else None

    entropy_reg = Entropy(use_logits=False) if entropy_weight is not None else None

    # Trainable parameters: everything in `hoppy` plus the relation embeddings,
    # each listed exactly once. A reformulator that receives `relation_embeddings`
    # (the attentive one) may already expose its weight through `hoppy.parameters()`.
    # A duplicate entry makes the optimizer warn and apply the update twice per step.
    params_lst = []
    seen_param_ids = set()
    for p in list(hoppy.parameters()) + list(relation_embeddings.parameters()):
        if p is entity_embeddings.weight or id(p) in seen_param_ids:
            continue
        seen_param_ids.add(id(p))
        params_lst.append(p)

    params = nn.ParameterList(params_lst).to(device)

    if load_path is not None:
        # NOTE(review): only `model` (the NeuralKB) is loaded/saved. Reformulator
        # and relation-embedding weights live elsewhere and are not checkpointed.
        # torch.load unpickles the file, so only load trusted checkpoints.
        model.load_state_dict(torch.load(load_path))

    for tensor in params_lst:
        logger.info(f'\t{tensor.size()}\t{tensor.device}')

    optimizer_factory = {
        'adagrad': lambda arg: optim.Adagrad(arg, lr=learning_rate),
        'adam': lambda arg: optim.Adam(arg, lr=learning_rate),
        'sgd': lambda arg: optim.SGD(arg, lr=learning_rate)
    }

    assert optimizer_name in optimizer_factory
    optimizer = optimizer_factory[optimizer_name](params)

    # Every relation is scored for every target. The label is 1 for relations that
    # map to the same predicate as the gold relation. This list never changes.
    p_relation_lst = sorted(relation_to_predicate.keys())

    global_step = 0

    for epoch_no in range(1, nb_epochs + 1):

        batcher = Batcher(batch_size=batch_size, nb_examples=nb_examples, nb_epochs=1, random_state=random_state)

        nb_batches = len(batcher.batches)
        epoch_loss_values = []

        for batch_no, (batch_start, batch_end) in enumerate(batcher.batches, start=1):
            global_step += 1

            indices_batch = batcher.get_batch(batch_start, batch_end)
            instances_batch = [data.train[i] for i in indices_batch]

            batch_loss_values = []

            # One optimizer step per instance. The batch only groups instances for logging.
            for instance in instances_batch:
                story, target = instance.story, instance.target
                s, r, o = target

                story_rel = encode_relation(story, relation_embeddings.weight, relation_to_idx, device)
                story_arg1, story_arg2 = encode_arguments(story, entity_embeddings.weight, entity_to_idx, device)

                embeddings = encode_entities(story, entity_embeddings.weight, entity_to_idx, device)

                facts = [story_rel, story_arg1, story_arg2]

                pos_predicate = relation_to_predicate[r]
                target_lst = [(s, x, o) for x in p_relation_lst]
                label_lst = [int(pos_predicate == relation_to_predicate[x]) for x in p_relation_lst]

                rel_emb = encode_relation(target_lst, relation_embeddings.weight, relation_to_idx, device)
                arg1_emb, arg2_emb = encode_arguments(target_lst, entity_embeddings.weight, entity_to_idx, device)

                scores = hoppy.score(rel_emb, arg1_emb, arg2_emb, facts, embeddings)
                labels = torch.tensor(label_lst, dtype=torch.float32).to(device)

                loss = loss_function(scores, labels)

                factors = [hoppy.factor(e) for e in [rel_emb, arg1_emb, arg2_emb]]

                if N2_weight is not None:
                    loss += N2_weight * N2_reg(factors)

                if N3_weight is not None:
                    loss += N3_weight * N3_reg(factors)

                if entropy_weight is not None:
                    # NOTE(review): assumes each reformulator exposes `.projection`
                    # (e.g. the attentive one). Other reformulators may not.
                    for hop, _ in hops_lst:
                        attn_logits = hop.projection(rel_emb)
                        attention = torch.softmax(attn_logits, dim=1)
                        loss += entropy_weight * entropy_reg([attention])

                loss_value = loss.item()

                batch_loss_values += [loss_value]
                epoch_loss_values += [loss_value]

                loss.backward()

                optimizer.step()
                optimizer.zero_grad()

            loss_mean, loss_std = np.mean(batch_loss_values), np.std(batch_loss_values)
            logger.info(f'Epoch {epoch_no}/{nb_epochs}\tBatch {batch_no}/{nb_batches}\tLoss {loss_mean:.4f} ± {loss_std:.4f}')

            if global_step % evaluate_every == 0:
                for test_path in test_paths:
                    instances = data.test[test_path]
                    evaluate(instances=instances, path=test_path)

                if is_debug is True:
                    with torch.no_grad():
                        show_rules(model=hoppy, kernel=kernel, relation_embeddings=relation_embeddings,
                                   data=data, relation_to_idx=relation_to_idx, device=device)

        loss_mean, loss_std = np.mean(epoch_loss_values), np.std(epoch_loss_values)

        slope_value = kernel.slope.item() if isinstance(kernel.slope, Tensor) else kernel.slope
        # `--slope` defaults to None. Avoid formatting None with `.4f`, which raises TypeError.
        slope_str = 'None' if slope_value is None else f'{slope_value:.4f}'
        logger.info(f'Epoch {epoch_no}/{nb_epochs}\tLoss {loss_mean:.4f} ± {loss_std:.4f}\tSlope {slope_str}')

    import time
    start = time.time()

    for test_path in test_paths:
        evaluate(instances=data.test[test_path], path=test_path)

    end = time.time()
    logger.info(f'Evaluation took {end - start} seconds.')

    if save_path is not None:
        torch.save(model.state_dict(), save_path)

    logger.info("Training finished")