def train(self):
    """Construct the data pipeline and a matching RNN, then train it.

    The RNN variant (many-to-one vs many-to-many) is chosen from
    ``self.is_many_to_one``; both variants take the same constructor
    arguments, so only the class differs. The actual optimisation loop
    is delegated to ``self._train``.
    """
    data_manager = DataManager(self.batch_size, logger=self.logger,
                               is_many_to_one=self.is_many_to_one,
                               data_file_count=self.data_file_count,
                               pretrained_file=self.pre_train)
    # Both model classes share one signature — select the class, build once.
    model_cls = RNN_M2O if self.is_many_to_one else RNN_M2M
    net = model_cls(len(data_manager.word_list), self.embedding_len,
                    self.hidden_size, self.learning_rate,
                    self.num_hidden_layer, self.drop_rate,
                    use_adam=True, use_cuda=self.use_cuda,
                    pretrained_emb=data_manager.pretrained_embeddings())
    self._train(net, data_manager)
def test(self, id):
    """Load a saved checkpoint identified by `id` and evaluate it on the test set.

    Args:
        id: checkpoint identifier of the form ``M2{M|O}_<lr>_<hidden>_<layers>``,
            e.g. ``M2O_001_300_2``; hyper-parameters are parsed back out of it.
            (Name shadows the builtin ``id`` — kept for interface compatibility.)

    Returns:
        (mean_test_loss, mean_test_acc, perplexity) averaged over the test set.

    Raises:
        AssertionError: if ``self._load`` reports the model file was not found.
    """
    # Recover hyper-parameters encoded in the id string.
    # NOTE(review): re.search returns None on a malformed id, so .groups()
    # would raise AttributeError — confirm ids are always well-formed.
    _, lr, hs, nh = re.search(r'M2(M|O)_([0-9]+)_([0-9]+)_([0-9]+)_?', id).groups()
    # The lr token drops its leading digit and becomes a "0."-prefixed float,
    # e.g. "001" -> 0.01 — presumably the inverse of how ids are generated
    # at save time; verify against the code that builds checkpoint ids.
    lr, hs, nh = float('0.'+lr[1:]), int(hs), int(nh)
    data_manager = DataManager(self.batch_size, logger=self.logger,
                               is_many_to_one=self.is_many_to_one,
                               data_file_count=self.data_file_count,
                               pretrained_file=self.pre_train, is_test=True)
    # Pick the model class matching the configured direction (M2O vs M2M);
    # both share the same constructor signature.
    if self.is_many_to_one:
        model = RNN_M2O
    else:
        model = RNN_M2M
    net = model(len(data_manager.word_list), self.embedding_len, hs, lr, nh,
                self.drop_rate, use_adam=True, use_cuda=self.use_cuda,
                pretrained_emb=data_manager.pretrained_embeddings())
    # Restore weights; only `status` is used here, the rest is training state.
    status, _epoch_index, _perplexity_history, _min_perplexity = self._load(net, id)
    if status:
        loss_fn = net.get_loss()
        # Testing
        # Running sums weighted by batch size, so the final means are
        # per-sample averages even with a ragged last batch.
        test_losses = 0.
        test_acc = 0.
        test_counter = 0
        net.eval()
        for data, label in data_manager.test_loader():
            # Legacy PyTorch (<0.4) API: wrap tensors in Variable explicitly.
            data = T.autograd.Variable(T.LongTensor(data))
            label = T.autograd.Variable(T.LongTensor(label))
            if self.use_cuda:
                data = data.cuda()
                label = label.cuda()
            output, predicted = net(data)
            # Flatten time/batch dims for the loss; `[0]` extracts the Python
            # scalar from a 0-dim tensor (pre-0.4 idiom, replaced by .item()
            # in later PyTorch — do not "modernise" without bumping the dep).
            test_losses += loss_fn(output.view(-1, len(data_manager.word_list)), label.view(-1)) \
                .data.cpu()[0] * data.size(0)
            # Token-level accuracy for this batch, weighted by batch size.
            test_acc += (label.squeeze() == predicted).float().mean().data * data.size(0)
            test_counter += data.size(0)
        mean_test_loss = test_losses/test_counter
        mean_test_acc = test_acc/test_counter
        # Perplexity is exp of the mean cross-entropy loss.
        perplexity = np.exp(mean_test_loss)
        self.logger.i('Loss: %.4f, Acc: %.4f, Perp: %.4f'%(mean_test_loss, mean_test_acc, perplexity))
        return mean_test_loss, mean_test_acc, perplexity
    else:
        raise AssertionError('Model file not found!')