Example 1
import torch
from unittest import TestCase

# Import paths below are assumptions based on the paragraph-vectors package
# layout; adjust them to the project under test.
from paragraphvec.models import DBOW
from paragraphvec.loss import NegativeSampling


class DBOWTest(TestCase):

    def setUp(self):
        self.batch_size = 2
        self.num_noise_words = 2
        self.num_docs = 3
        self.num_words = 15
        self.vec_dim = 10

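        # Two training examples: documents 1 and 2 out of the 3 documents;
        # each row of target_noise_ids holds one target word id followed by
        # num_noise_words noise word ids.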
        self.doc_ids = torch.LongTensor([1, 2])
        self.target_noise_ids = torch.LongTensor([[1, 3, 4], [2, 4, 7]])
        self.model = DBOW(
            self.vec_dim, self.num_docs, self.num_words)

    def test_num_parameters(self):
        self.assertEqual(
            sum([x.size()[0] * x.size()[1] for x in self.model.parameters()]),
            self.num_docs * self.vec_dim + self.num_words * self.vec_dim)

    def test_forward(self):
        x = self.model.forward(self.doc_ids, self.target_noise_ids)

        self.assertEqual(x.size()[0], self.batch_size)
        self.assertEqual(x.size()[1], self.num_noise_words + 1)

    def test_backward(self):
        cost_func = NegativeSampling()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001)
        for _ in range(2):
            x = self.model.forward(self.doc_ids, self.target_noise_ids)
            x = cost_func.forward(x)
            self.model.zero_grad()
            x.backward()
            optimizer.step()

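        # Only document vectors that appeared in the batch (ids 1 and 2) and
        # output vectors of words used as target or noise ids should have
        # accumulated non-zero gradients.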
        self.assertEqual(torch.sum(self.model._D.grad[0, :].data), 0)
        self.assertNotEqual(torch.sum(self.model._D.grad[1, :].data), 0)
        self.assertNotEqual(torch.sum(self.model._D.grad[2, :].data), 0)

        target_noise_ids = self.target_noise_ids.numpy().flatten()

        for word_id in range(self.num_words):
            if word_id in target_noise_ids:
                self.assertNotEqual(
                    torch.sum(self.model._O.grad[:, word_id].data), 0)
            else:
                self.assertEqual(
                    torch.sum(self.model._O.grad[:, word_id].data), 0)
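
The DBOW class exercised by these tests is taken as given. Purely for orientation, a minimal sketch of a module with the same interface and parameter shapes might look as follows; the randn initialisation and the exact tensor operations are assumptions, not necessarily the library's actual implementation.

import torch
import torch.nn as nn


class DBOW(nn.Module):
    """Distributed Bag of Words paragraph-vector model (sketch)."""

    def __init__(self, vec_dim, num_docs, num_words):
        super().__init__()
        # one trainable vector per document and one output vector per word
        self._D = nn.Parameter(torch.randn(num_docs, vec_dim))
        self._O = nn.Parameter(torch.randn(vec_dim, num_words))

    def forward(self, doc_ids, target_noise_ids):
        # (batch, 1, vec_dim) x (batch, vec_dim, num_noise_words + 1)
        # -> one score per target/noise word for every document in the batch
        return torch.bmm(
            self._D[doc_ids, :].unsqueeze(1),
            self._O[:, target_noise_ids].permute(1, 0, 2)).squeeze(1)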
import time

import torch
from torch.optim import Adam

# Import paths below are assumptions based on the paragraph-vectors package
# layout; save_training_state and _print_progress are helpers for
# checkpointing and console progress reporting.
from paragraphvec.models import DM, DBOW
from paragraphvec.loss import NegativeSampling
from paragraphvec.utils import save_training_state, _print_progress


def _run(data_file_name, dataset, data_generator, num_batches, vocabulary_size,
         context_size, num_noise_words, vec_dim, num_epochs, batch_size, lr,
         model_ver, vec_combine_method, save_all, generate_plot,
         model_ver_is_dbow):

    if model_ver_is_dbow:
        model = DBOW(vec_dim, num_docs=len(dataset), num_words=vocabulary_size)
    else:
        model = DM(vec_dim, num_docs=len(dataset), num_words=vocabulary_size)

    cost_func = NegativeSampling()
    optimizer = Adam(params=model.parameters(), lr=lr)

    if torch.cuda.is_available():
        model.cuda()

    print("Dataset comprises {:d} documents.".format(len(dataset)))
    print("Vocabulary size is {:d}.\n".format(vocabulary_size))
    print("Training started.")

    best_loss = float("inf")
    prev_model_file_path = None

    for epoch_i in range(num_epochs):
        epoch_start_time = time.time()
        loss = []

        for batch_i in range(num_batches):
            batch = next(data_generator)
            if torch.cuda.is_available():
                batch.cuda_()

            if model_ver_is_dbow:
                x = model.forward(batch.doc_ids, batch.target_noise_ids)
            else:
                x = model.forward(batch.context_ids, batch.doc_ids,
                                  batch.target_noise_ids)

            x = cost_func.forward(x)

            loss.append(x.item())
            model.zero_grad()
            x.backward()
            optimizer.step()
            _print_progress(epoch_i, batch_i, num_batches)

        # end of epoch
        loss = torch.mean(torch.FloatTensor(loss))
        is_best_loss = loss < best_loss
        best_loss = min(loss, best_loss)

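        # Checkpoint the model and optimizer state so training can be resumed.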
        state = {
            'epoch': epoch_i + 1,
            'model_state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer_state_dict': optimizer.state_dict()
        }

        prev_model_file_path = save_training_state(
            data_file_name, model_ver, vec_combine_method, context_size,
            num_noise_words, vec_dim, batch_size, lr, epoch_i, loss, state,
            save_all, generate_plot, is_best_loss, prev_model_file_path,
            model_ver_is_dbow)

        epoch_total_time = round(time.time() - epoch_start_time)
        print(" ({:d}s) - loss: {:.4f}".format(epoch_total_time, loss))
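
Both examples rely on NegativeSampling as the cost function without showing it. Assuming the first column of the score matrix returned by the model belongs to the true target word and the remaining columns to sampled noise words, a minimal sketch of such a loss could be:

import torch
import torch.nn as nn


class NegativeSampling(nn.Module):
    """Negative-sampling loss over (target, noise, ...) scores (sketch)."""

    def __init__(self):
        super().__init__()
        self._log_sigmoid = nn.LogSigmoid()

    def forward(self, scores):
        # maximise log sigma(target score) + sum of log sigma(-noise scores),
        # averaged over the batch (hence the leading minus sign)
        batch_size = scores.size(0)
        return -torch.sum(
            self._log_sigmoid(scores[:, 0])
            + torch.sum(self._log_sigmoid(-scores[:, 1:]), dim=1)
        ) / batch_size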