Beispiel #1
0
    def test_conv_e_predict(self):
        """Test ConvE's predict function.

        Builds a ConvE model from ``CONV_E_CONFIG``, scores ``TEST_TRIPLES``
        and checks that one prediction is returned per input triple.
        """
        conv_e = ConvE(config=CONV_E_CONFIG)

        predictions = conv_e.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # ``assertTrue(type(x), int)`` would treat ``int`` as the failure
        # message and always pass; check the type explicitly instead.
        self.assertIsInstance(predictions.shape[0], int)
Beispiel #2
0
    def test_conv_e_predict(self):
        """Test ConvE's predict function.

        Builds a ConvE model from the keyword arguments in ``CONV_E_CONFIG``,
        sets the entity/relation counts on the instance, scores
        ``TEST_TRIPLES`` and checks that one prediction is returned per
        input triple.
        """
        conv_e = ConvE(**CONV_E_CONFIG)
        conv_e.num_entities = CONV_E_CONFIG[NUM_ENTITIES]
        conv_e.num_relations = CONV_E_CONFIG[NUM_RELATIONS]
        predictions = conv_e.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # ``assertTrue(type(x), int)`` would treat ``int`` as the failure
        # message and always pass; check the type explicitly instead.
        self.assertIsInstance(predictions.shape[0], int)
Beispiel #3
0
 def test_instantiate_conv_e(self):
     """Check that a ConvE model can be built and exposes the expected sizes."""
     model = ConvE(**CONV_E_CONFIG)
     # The entity and relation counts are set on the instance after
     # construction, taken from the same configuration mapping.
     model.num_entities = CONV_E_CONFIG[NUM_ENTITIES]
     model.num_relations = CONV_E_CONFIG[NUM_RELATIONS]

     self.assertIsNotNone(model)
     # Every checked attribute is expected to equal 5.
     for attribute in ('num_entities', 'num_relations', 'embedding_dim'):
         self.assertEqual(getattr(model, attribute), 5)
Beispiel #4
0
 def test_instantiate_conv_e(self):
     """Check that a ConvE model built from a config exposes the expected sizes."""
     model = ConvE(config=CONV_E_CONFIG)
     self.assertIsNotNone(model)

     # Attribute name -> value the instantiated model is expected to report.
     expected_sizes = {
         'num_entities': 5,
         'num_relations': 5,
         'embedding_dim': 5,
     }
     for attribute, expected in expected_sizes.items():
         self.assertEqual(getattr(model, attribute), expected)
Beispiel #5
0
def _train_conv_e_model(
    kge_model: ConvE,
    all_entities,
    learning_rate,
    num_epochs,
    batch_size,
    pos_triples,
    device,
    seed: Optional[int] = None,
    tqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> Tuple[ConvE, List[float]]:
    """Train a ConvE model on positive triples and randomly corrupted negatives.

    In every epoch the positive triples are shuffled and split into batches
    of ``batch_size // 2`` triples. For each positive batch an equally sized
    negative batch is built: the first half of the positives gets a randomly
    drawn subject, the remainder a randomly drawn object. Positives (label 1)
    and negatives (label 0) are concatenated, shuffled together and passed to
    the model, whose return value is used as the loss to optimise with Adam.

    :param kge_model: The model to train; it is moved to ``device``. Calling
        it with ``(batch, labels)`` must return a scalar loss tensor.
    :param all_entities: Array of entity ids to draw corrupted subjects and
        objects from (assumed to be a 1-D NumPy array -- TODO confirm).
    :param learning_rate: Learning rate handed to the Adam optimizer.
    :param num_epochs: Number of passes over ``pos_triples``.
    :param batch_size: Total batch size; half of it is used for positives.
    :param pos_triples: Positive triples, indexed as a 2-D array whose
        columns are subject, relation and object (assumed to be a NumPy
        array -- TODO confirm).
    :param device: The torch device to train on.
    :param seed: If given, seeds NumPy's global random state and is used as
        the ``random_state`` when shuffling each batch.
    :param tqdm_kwargs: Extra keyword arguments for the epoch progress bar;
        they override the default ``desc``.
    :return: The trained model and the loss of every epoch, i.e. the batch
        losses weighted by the number of positives in the batch and divided
        by the number of positive triples.
    """
    if seed is not None:
        np.random.seed(seed=seed)

    kge_model = kge_model.to(device)

    optimizer = optim.Adam(kge_model.parameters(), lr=learning_rate)

    loss_per_epoch: List[float] = []

    log.info('****Run Model On %s****', str(device).upper())

    num_pos_triples = pos_triples.shape[0]
    num_entities = all_entities.shape[0]

    # Half of each training batch is positive, the other half negative.
    # TODO: Make sure that batch = num_pos + num_negs
    num_positives = batch_size // 2

    start_training = timeit.default_timer()

    _tqdm_kwargs = dict(desc='Training epoch')
    if tqdm_kwargs:
        _tqdm_kwargs.update(tqdm_kwargs)

    for _ in trange(num_epochs, **_tqdm_kwargs):
        # Visit the positive triples in a fresh random order every epoch.
        indices = np.arange(num_pos_triples)
        np.random.shuffle(indices)
        pos_triples = pos_triples[indices]

        pos_batches = _split_list_in_batches(
            input_list=pos_triples,
            batch_size=num_positives,
        )
        current_epoch_loss = 0.

        for pos_batch in pos_batches:
            # TODO: Remove original subject and object from entity set
            current_batch_size = len(pos_batch)
            batch_subjs = pos_batch[:, 0:1]
            batch_relations = pos_batch[:, 1:2]
            batch_objs = pos_batch[:, 2:3]

            num_subj_corrupt = current_batch_size // 2
            num_obj_corrupt = current_batch_size - num_subj_corrupt

            # First part of the batch: replace the subject by a random entity.
            corrupted_subj_indices = np.random.choice(
                num_entities,
                size=num_subj_corrupt,
            )
            corrupted_subjects = np.reshape(
                all_entities[corrupted_subj_indices],
                (-1, 1),
            )
            subject_based_corrupted_triples = np.concatenate(
                [
                    corrupted_subjects,
                    batch_relations[:num_subj_corrupt],
                    batch_objs[:num_subj_corrupt],
                ],
                axis=1,
            )

            # Remaining part: replace the object by a random entity.
            corrupted_obj_indices = np.random.choice(
                num_entities,
                size=num_obj_corrupt,
            )
            corrupted_objects = np.reshape(
                all_entities[corrupted_obj_indices],
                (-1, 1),
            )
            object_based_corrupted_triples = np.concatenate(
                [
                    batch_subjs[num_subj_corrupt:],
                    batch_relations[num_subj_corrupt:],
                    corrupted_objects,
                ],
                axis=1,
            )

            # Stay in NumPy until the batch is fully assembled and shuffled:
            # np.concatenate cannot consume tensors that live on a CUDA
            # device, so the conversion to torch happens only once, below.
            batch = np.concatenate(
                [
                    pos_batch,
                    subject_based_corrupted_triples,
                    object_based_corrupted_triples,
                ],
                axis=0,
            )
            labels = np.concatenate(
                [
                    np.ones(shape=current_batch_size),
                    np.zeros(shape=current_batch_size),
                ],
                axis=0,
            )

            batch, labels = shuffle(batch, labels, random_state=seed)

            batch = torch.tensor(batch, dtype=torch.long, device=device)
            labels = torch.tensor(labels, dtype=torch.float, device=device)

            # Recall that torch *accumulates* gradients. Before passing in a
            # new instance, you need to zero out the gradients from the old
            # instance
            optimizer.zero_grad()
            loss = kge_model(batch, labels)
            # Weight by the number of positives so the epoch loss below is
            # an average over all positive triples.
            current_epoch_loss += (loss.item() * current_batch_size)

            loss.backward()
            optimizer.step()

        # Track epoch loss
        loss_per_epoch.append(current_epoch_loss / len(pos_triples))

    stop_training = timeit.default_timer()
    log.info("Training took %s seconds \n",
             round(stop_training - start_training))

    return kge_model, loss_per_epoch