import re
from os.path import join

import torch

# DBOW, DM and MODELS_DIR are defined elsewhere in the project.


def _load_model(model_file_name, vec_dim, num_docs, num_words):
    match = re.search(r'_model\.(dm|dbow)', model_file_name)
    if match is None:
        raise ValueError(
            "Model file name contains an invalid version of the model")
    model_ver = match.group(1)

    model_file_path = join(MODELS_DIR, model_file_name)
    try:
        checkpoint = torch.load(model_file_path)
    except AssertionError:
        # checkpoint was saved on a GPU; remap tensors to the CPU
        checkpoint = torch.load(
            model_file_path,
            map_location=lambda storage, location: storage)

    if model_ver == 'dbow':
        model = DBOW(vec_dim, num_docs, num_words)
    else:
        model = DM(vec_dim, num_docs, num_words)

    model.load_state_dict(checkpoint['model_state_dict'])
    return model
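# Usage sketch (hypothetical): the file name and dimensions below are
# made up. _load_model only requires that the name contain '_model.dm'
# or '_model.dbow' and that a matching checkpoint exist under MODELS_DIR.
def _example_load_model():
    model = _load_model(
        'example_model.dbow',
        vec_dim=100,
        num_docs=3000,
        num_words=15000)
    model.eval()  # evaluation mode; the vectors are read, not trained
    return model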
class DBOWTest(TestCase):

    def setUp(self):
        self.batch_size = 2
        self.num_noise_words = 2
        self.num_docs = 3
        self.num_words = 15
        self.vec_dim = 10

        self.doc_ids = torch.LongTensor([1, 2])
        self.target_noise_ids = torch.LongTensor([[1, 3, 4], [2, 4, 7]])
        self.model = DBOW(self.vec_dim, self.num_docs, self.num_words)

    def test_num_parameters(self):
        # _D (num_docs x vec_dim) plus _O (vec_dim x num_words)
        self.assertEqual(
            sum(x.size()[0] * x.size()[1] for x in self.model.parameters()),
            self.num_docs * self.vec_dim + self.num_words * self.vec_dim)

    def test_forward(self):
        x = self.model.forward(self.doc_ids, self.target_noise_ids)
        self.assertEqual(x.size()[0], self.batch_size)
        self.assertEqual(x.size()[1], self.num_noise_words + 1)

    def test_backward(self):
        cost_func = NegativeSampling()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001)

        for _ in range(2):
            x = self.model.forward(self.doc_ids, self.target_noise_ids)
            x = cost_func.forward(x)
            self.model.zero_grad()
            x.backward()
            optimizer.step()

        # document 0 is not in the batch, so its vector gets no gradient
        self.assertEqual(torch.sum(self.model._D.grad[0, :].data), 0)
        self.assertNotEqual(torch.sum(self.model._D.grad[1, :].data), 0)
        self.assertNotEqual(torch.sum(self.model._D.grad[2, :].data), 0)

        # only words that appear as targets or noise receive gradients
        target_noise_ids = self.target_noise_ids.numpy().flatten()
        for word_id in range(self.num_words):
            if word_id in target_noise_ids:
                self.assertNotEqual(
                    torch.sum(self.model._O.grad[:, word_id].data), 0)
            else:
                self.assertEqual(
                    torch.sum(self.model._O.grad[:, word_id].data), 0)
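# The tests above index two parameter matrices: _D (document vectors)
# and _O (output word vectors). The repository's DBOW implementation is
# not reproduced in this section, so the following is a minimal sketch
# consistent with what DBOWTest asserts: _D is (num_docs, vec_dim), _O
# is (vec_dim, num_words), and forward returns one score per target or
# noise word. It also shows why _D.grad[0] stays zero above: document 0
# is never indexed by the batch.
import torch
import torch.nn as nn


class DBOWSketch(nn.Module):
    """Minimal distributed bag-of-words sketch (an assumption, not
    necessarily the repository's DBOW)."""

    def __init__(self, vec_dim, num_docs, num_words):
        super(DBOWSketch, self).__init__()
        # one vec_dim-dimensional vector per document
        self._D = nn.Parameter(torch.randn(num_docs, vec_dim))
        # one output vector per vocabulary word, stored column-wise
        self._O = nn.Parameter(torch.randn(vec_dim, num_words))

    def forward(self, doc_ids, target_noise_ids):
        # (batch, 1, vec_dim) x (batch, vec_dim, 1 + num_noise_words)
        # -> (batch, 1 + num_noise_words) scores
        return torch.bmm(
            self._D[doc_ids, :].unsqueeze(1),
            self._O[:, target_noise_ids].permute(1, 0, 2)).squeeze(1)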
def _run(data_file_name,
         dataset,
         data_generator,
         num_batches,
         vocabulary_size,
         context_size,
         num_noise_words,
         vec_dim,
         num_epochs,
         batch_size,
         lr,
         model_ver,
         vec_combine_method,
         save_all,
         generate_plot,
         model_ver_is_dbow):
    if model_ver_is_dbow:
        model = DBOW(vec_dim, num_docs=len(dataset), num_words=vocabulary_size)
    else:
        model = DM(vec_dim, num_docs=len(dataset), num_words=vocabulary_size)

    cost_func = NegativeSampling()
    optimizer = Adam(params=model.parameters(), lr=lr)

    if torch.cuda.is_available():
        model.cuda()

    print("Dataset comprised of {:d} documents.".format(len(dataset)))
    print("Vocabulary size is {:d}.\n".format(vocabulary_size))
    print("Training started.")

    best_loss = float("inf")
    prev_model_file_path = None

    for epoch_i in range(num_epochs):
        epoch_start_time = time.time()
        loss = []

        for batch_i in range(num_batches):
            batch = next(data_generator)
            if torch.cuda.is_available():
                batch.cuda_()

            if model_ver_is_dbow:
                x = model.forward(batch.doc_ids, batch.target_noise_ids)
            else:
                x = model.forward(
                    batch.context_ids,
                    batch.doc_ids,
                    batch.target_noise_ids)

            x = cost_func.forward(x)
            loss.append(x.item())

            model.zero_grad()
            x.backward()
            optimizer.step()
            _print_progress(epoch_i, batch_i, num_batches)

        # end of epoch
        loss = torch.mean(torch.FloatTensor(loss))
        is_best_loss = loss < best_loss
        best_loss = min(loss, best_loss)

        state = {
            'epoch': epoch_i + 1,
            'model_state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer_state_dict': optimizer.state_dict()
        }

        prev_model_file_path = save_training_state(
            data_file_name,
            model_ver,
            vec_combine_method,
            context_size,
            num_noise_words,
            vec_dim,
            batch_size,
            lr,
            epoch_i,
            loss,
            state,
            save_all,
            generate_plot,
            is_best_loss,
            prev_model_file_path,
            model_ver_is_dbow)

        epoch_total_time = round(time.time() - epoch_start_time)
        print(" ({:d}s) - loss: {:.4f}".format(epoch_total_time, loss))
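# Both DBOWTest and _run rely on NegativeSampling to reduce the
# (batch_size, 1 + num_noise_words) score matrix to a scalar loss. Its
# implementation is not shown in this section; below is a minimal sketch
# of the standard negative-sampling objective (an assumption, not
# necessarily the repository's code), with the true target word's score
# in column 0 and the noise words' scores in the remaining columns.
import torch
import torch.nn as nn


class NegativeSamplingSketch(nn.Module):
    """Minimal negative-sampling loss sketch."""

    def __init__(self):
        super(NegativeSamplingSketch, self).__init__()
        self._log_sigmoid = nn.LogSigmoid()

    def forward(self, scores):
        batch_size = scores.size(0)
        # raise the target word's score ...
        positive = self._log_sigmoid(scores[:, 0])
        # ... and lower the noise words' scores
        negative = torch.sum(self._log_sigmoid(-scores[:, 1:]), dim=1)
        # average over the batch so the loss is batch-size independent
        return -torch.sum(positive + negative) / batch_size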