def evaluate(self, split="test"):
        """
        Compute the average loss and the accuracy of ``self.model`` on a split.

        Args:
            split: ``"test"`` iterates ``self.test_loader``; any other value
                iterates ``self.train_loader``.

        Returns:
            A ``(loss, accuracy)`` tuple of floats:
            - the loss accumulated over all batches divided by the number of
              examples seen;
            - the fraction of examples whose predicted label equals the target.

        Raises:
            ZeroDivisionError: if the selected loader yields no examples.

        The model is switched to eval mode for the pass and is always switched
        back to train mode afterwards, even if an exception is raised.
        """
        # Compare by value: `is` tests object identity, so a "test" string built
        # at runtime could silently select the training set instead.
        loader = self.test_loader if split == "test" else self.train_loader

        num_examples = 0
        num_correct = 0
        loss = 0.0

        self.model.eval()
        try:
            # Gradients are never used here; skipping graph construction saves
            # memory and time without changing any computed value.
            with torch.no_grad():
                for batch in loader:
                    # `Variable` wrappers are no-ops since PyTorch 0.4 (already
                    # required by `.item()` below), so plain tensors are used.
                    input_data, target_data = batch[0], batch[1]
                    if self.cuda:
                        input_data = input_data.cuda()
                        target_data = target_data.cuda()

                    output_data = self.model(input_data)

                    num_examples += input_data.shape[0]
                    # is_normalize=False presumably yields the un-averaged
                    # batch loss, so the final division by num_examples gives a
                    # per-example mean — TODO confirm against compute_loss.
                    loss += float(
                        compute_loss(self.model,
                                     output_data,
                                     target_data,
                                     is_normalize=False))
                    predicted_labels = predict_labels(output_data)
                    num_correct += torch.sum(
                        predicted_labels == target_data).item()
        finally:
            # Restore train mode even on failure so a training loop that calls
            # this does not silently continue in eval mode.
            self.model.train()

        return loss / float(num_examples), float(num_correct) / float(
            num_examples)
Example #2
0
    def get_accuracy(self, split='test'):
        '''
        Compute the accuracy of ``self.model`` on the test or train split.

        Args:
            split: ``'test'`` iterates ``self.test_loader``; any other value
                iterates ``self.train_loader``.

        Returns:
            The fraction (float) of examples whose predicted label equals the
            target label.

        Raises:
            ZeroDivisionError: if the selected loader yields no examples.

        The model is switched to eval mode for the pass and is always switched
        back to train mode afterwards, even if an exception is raised.
        '''
        # Compare by value: `is` tests object identity, so a 'test' string built
        # at runtime could silently select the training set instead.
        loader = self.test_loader if split == 'test' else self.train_loader

        num_examples = 0
        num_correct = 0

        self.model.eval()
        try:
            # No gradients are needed to measure accuracy; avoid building the
            # autograd graph for every batch.
            with torch.no_grad():
                for batch in loader:
                    # `Variable` wrappers are no-ops since PyTorch 0.4 (already
                    # required by `.item()` below), so plain tensors are used.
                    input_data, target_data = batch[0], batch[1]
                    if self.cuda:
                        input_data = input_data.cuda()
                        target_data = target_data.cuda()

                    num_examples += input_data.shape[0]
                    predicted_labels = predict_labels(self.model, input_data)
                    num_correct += torch.sum(
                        predicted_labels == target_data).item()
        finally:
            # Restore train mode even on failure so a training loop that calls
            # this does not silently continue in eval mode.
            self.model.train()

        return float(num_correct) / float(num_examples)
Example #3
0
def test_predict_labels():
    """
    Run predict_labels on a dummy TestModel with a single fixed input row of
    five values and check that the predicted label is 4.
    """
    dummy_net = TestModel()

    # One sample with five features, i.e. shape (1, 5), as float32.
    features = torch.tensor([[1.4, -1.4, -0.7, 2.3, 0.3]], dtype=torch.float32)

    predicted = predict_labels(dummy_net, features)
    expected_label = 4
    assert predicted.item() == expected_label