コード例 #1
0
def test_data_loader():
    """Check DataLoader batching and subset handling.

    Uses a dataset without subsets ('unit_test/dummy42'), which must expose
    only the 'DEFAULT' subset, and one with subsets
    ('unit_test/dummy42_dict'), which must expose 'sub1' and 'sub2'. Iterating
    the loader over each subset of the latter must reproduce the pose sum of
    that subset's complete data.

    NOTE(review): ``dataset_size`` is not defined in this function; it is
    presumably a module-level constant of the test file — confirm.
    """
    distorter = distorters.NoDistorter()
    dataset_plain = datasets.NormalizedPairedPoseDataset(
        'unit_test/dummy42', distorter, norm.NoNorm, True, dataset_size,
        'cuda:0')
    dataset_dict = datasets.NormalizedPairedPoseDataset(
        'unit_test/dummy42_dict', distorter, norm.NoNorm, True, dataset_size,
        'cuda:0')
    # Reference sums taken directly from the dataset. The first one assumes
    # that 'sub1' is the subset selected right after construction.
    true_sum_sub1 = dataset_dict[:].poses.sum()
    dataset_dict.select_subset('sub2')
    true_sum_sub2 = dataset_dict[:].poses.sum()

    data_loader_plain = datasets.DataLoader(dataset_plain, 6)
    data_loader_dict = datasets.DataLoader(dataset_dict, 6)

    plain_batch = next(iter(data_loader_plain))
    subset_names_plain = data_loader_plain.get_subset_names()
    data_loader_plain.select_subset(subset_names_plain[0])

    # Collect every batch of every subset and accumulate their pose sums.
    all_batches = {}
    sum_of_subsets = {}
    for subset_name in data_loader_dict.get_subset_names():
        data_loader_dict.select_subset(subset_name)
        all_batches[subset_name] = list(data_loader_dict)
        sum_of_subsets[subset_name] = sum(
            batch.poses.sum() for batch in all_batches[subset_name])

    # Exact-type checks use identity (`is`), consistently for both batches.
    assert type(plain_batch) is datasets.PoseCorrectionBatch
    assert subset_names_plain == ['DEFAULT']
    assert list(all_batches.keys()) == ['sub1', 'sub2']
    assert len(all_batches['sub1']) == 7
    assert type(all_batches['sub1'][0]) is datasets.PoseCorrectionBatch
    assert all_batches['sub1'][0].labels.shape == (6, 21, 3)
    assert torch.allclose(sum_of_subsets['sub1'], true_sum_sub1)
    assert torch.allclose(sum_of_subsets['sub2'], true_sum_sub2)
コード例 #2
0
    def __init__(self,
                 model_directory,
                 hyperparams,
                 dataset,
                 training_config,
                 batch_size=100):
        """
        Load a trained model and its stored training artefacts for evaluation.

        :param model_directory: Directory containing 'weights.pt', 'log.pt'
                                and 'eval.pt'.
        :param hyperparams: Dict providing at least 'model' (a callable that
                            builds the model) and 'model_args' (its argument).
        :param dataset: Dataset to evaluate on; wrapped in an unshuffled
                        datasets.DataLoader.
        :param training_config: Stored unchanged on the instance.
        :param batch_size: Batch size of the evaluation data loader.
        """
        self.dataset = dataset
        self.data_loader = datasets.DataLoader(self.dataset,
                                               batch_size,
                                               shuffle=False)

        self.training_config = training_config
        self.hyperparams = hyperparams
        self.model = hyperparams['model'](hyperparams['model_args'])
        # Load onto the CPU so checkpoints saved on any device can be
        # restored; the model is moved to the GPU below.
        # NOTE(review): torch.load unpickles data — only load trusted files.
        weights = torch.load(os.path.join(model_directory, 'weights.pt'),
                             map_location='cpu')
        try:
            self.model.load_state_dict(weights)
        except RuntimeError as err:
            # Best effort: continue with the freshly built model, but report
            # why loading failed instead of discarding the reason.
            print('WARNING: Unable to load state dict: {}'.format(err))
        self.model.cuda()

        # Training log and previously logged evaluation results (CPU tensors).
        self.log = torch.load(os.path.join(model_directory, 'log.pt'),
                              map_location='cpu')
        self.logged_results = torch.load(os.path.join(model_directory,
                                                      'eval.pt'),
                                         map_location='cpu')

        # Populated later, presumably by other methods of this class.
        self.errors = None
        self.dataset_errors = None
        self.predictions = None
コード例 #3
0
def test_training_session():
    """Smoke-test TrainingSession: one training step plus one evaluation pass.

    NOTE(review): ``distorter``, ``model`` and ``hyperparams`` are not defined
    in this function; they are presumably module-level fixtures — confirm.
    """
    train_data = datasets.NormalizedPairedPoseDataset('unit_test/dummy42',
                                                      distorter,
                                                      norm.NoNorm,
                                                      False,
                                                      None,
                                                      device='cuda:0')
    val_data = datasets.NormalizedPairedPoseDataset('unit_test/dummy42_dict',
                                                    distorters.NoDistorter(),
                                                    norm.NoNorm,
                                                    True,
                                                    None,
                                                    device='cuda:0')
    full_batch = train_data[:]
    loader = datasets.DataLoader(val_data, batch_size=6)

    session = TrainingSession(model, hyperparams, norm.NoNorm)
    session.schedule_learning_rate()
    step_loss, step_output = session.train_batch(full_batch)
    evaluation = session.test_model(loader)

    # The training step yields a scalar CPU loss and an output batch shaped
    # like its input.
    assert step_loss.numel() == 1
    assert step_loss.device == torch.device('cpu')
    assert full_batch.poses.is_same_size(step_output.poses)

    # Evaluation results are keyed by subset name, then by metric name.
    assert list(evaluation.keys()) == ['sub1', 'sub2']
    first_subset = evaluation['sub1']
    assert list(first_subset.keys()) == Evaluator.metric_names
    assert first_subset['distance'].numel() == 1
コード例 #4
0
def test_to_model():
    """Evaluator.to_model must report all-zero metrics for DummyModel.

    Runs on a dataset without subsets (reported under 'DEFAULT') and on one
    with subsets ('sub1', 'sub2'), both in the default space and in the
    'original' space. Presumably DummyModel leaves the poses unchanged, which
    is why every metric is expected to be zero — confirm in helpers.
    """
    no_distortion = distorters.NoDistorter()
    dummy = helpers.DummyModel()
    plain_dataset = datasets.NormalizedPairedPoseDataset('unit_test/dummy42',
                                                         no_distortion,
                                                         norm.NoNorm,
                                                         False,
                                                         device='cuda:0')
    subset_dataset = datasets.NormalizedPairedPoseDataset('unit_test/ident42',
                                                          no_distortion,
                                                          norm.NoNorm,
                                                          True,
                                                          device='cuda:0')
    plain_loader = datasets.DataLoader(plain_dataset, 6)
    subset_loader = datasets.DataLoader(subset_dataset, 6)

    n_samples = 42
    expected = {
        name: torch.zeros(n_samples, device='cuda:0')
        for name in ('coord_diff', 'distance', 'bone_length', 'proportion')
    }

    plain_default = Evaluator.to_model(plain_loader, dummy)
    plain_original = Evaluator.to_model(plain_loader, dummy, space='original')
    subsets_default = Evaluator.to_model(subset_loader, dummy)
    subsets_original = Evaluator.to_model(subset_loader, dummy,
                                          space='original')

    def matches_expected(actual, metric):
        # Tolerance absorbs float round-off from normalization.
        return torch.allclose(actual, expected[metric], atol=1e-5)

    for metric in Evaluator.metric_names:
        for subset in ('sub1', 'sub2'):
            assert matches_expected(subsets_default[subset][metric], metric)
            assert matches_expected(subsets_original[subset][metric], metric)
        assert matches_expected(plain_default['DEFAULT'][metric], metric)
        assert matches_expected(plain_original['DEFAULT'][metric], metric)
コード例 #5
0
    def __init__(self, config, train_set, val_set):
        """
        Initialize a new Solver object.
        :param config: Dictionary with parameters defining the general solver behavior, e.g.
                       verbosity, logging, etc. Must contain at least 'batch_size'.
        :type config: dict

        :param train_set: Training data set that returns tuples of 2 pose tensors when iterated.
        :type train_set: torch.utils.data.Dataset

        :param val_set: Validation data set that returns tuples of 2 pose tensors when iterated.
        :type val_set: torch.utils.data.Dataset

        :raises TypeError: If config is not a dict (or dict subclass). TypeError derives from
                           Exception, so existing ``except Exception`` handlers still catch it.
        """
        if not isinstance(config, dict):
            raise TypeError(
                "Error: The passed config object is not a dictionary.")
        self.config = config

        # Subset shuffling is enabled only when the training set both uses a
        # preset and has subsets.
        shuffle_subsets = bool(train_set.use_preset and train_set.has_subsets)
        self.train_loader = datasets.DataLoader(
            train_set,
            config['batch_size'],
            shuffle=True,
            shuffle_subsets=shuffle_subsets)
        # Validation batches are capped at 10000 samples; smaller sets are
        # served as a single batch.
        self.val_loader = datasets.DataLoader(val_set,
                                              min(len(val_set), 10000),
                                              shuffle=False)
        # Not every dataset type carries a normalizer.
        self.normalizer = getattr(self.train_loader.dataset, 'normalizer', None)

        self.iters_per_epoch = len(self.train_loader)
コード例 #6
0
"""
Compute the mean evaluation metrics on a dataset (before applying any corrections).
"""

import os
import torch

from data_utils import datasets
from evaluation.evaluator import Evaluator

# Config
dataset_name = 'HANDS17_DPREN_SubjClust_val'
output_dir = os.path.join('results', 'datasets')
########################################################################

dataset = datasets.PairedPoseDataset(dataset_name, use_preset=True)
# Second argument is the loader's batch size.
data_loader = datasets.DataLoader(dataset, 100000)
results = Evaluator.means_per_metric(Evaluator.to_dataset(data_loader, 'default'))

# torch.save does not create missing directories; ensure the target exists.
os.makedirs(output_dir, exist_ok=True)
torch.save(results, os.path.join(output_dir, dataset_name + '.pt'))
コード例 #7
0
ファイル: predict.py プロジェクト: TheFloe1995/correct-pose
import os
import torch

from data_utils import datasets

# Config
dataset_name = 'HANDS17_DPREN_test_poses'
model_path = 'results/hands17sc_gnn_final_best_1/0'
repetition_number = 0
device = torch.device('cuda:0')
# NOTE(review): output_name is not used in the visible part of this script.
output_name = 'HANDS17sc_DPREN_gnn_fb0_test'
batch_size = int(1e5)
########################################################################

dataset = datasets.SinglePoseDataset(dataset_name, device=device)
data_loader = datasets.DataLoader(dataset, batch_size, shuffle=False)

# NOTE(review): torch.load unpickles arbitrary objects (params.pt provides a
# callable used to build the model) — only load trusted files.
hyperparams = torch.load(os.path.join(model_path, 'params.pt'))
model = hyperparams['model'](hyperparams['model_args'])
model.to(device)

weights = torch.load(os.path.join(model_path, str(repetition_number),
                                  'weights.pt'),
                     map_location=device)
model.load_state_dict(weights)

model.eval()
# Inference only: disable autograd so no computation graph is built and
# intermediate activations are not retained across batches.
prediction_list = []
with torch.no_grad():
    for batch in data_loader:
        prediction_list.append(model(batch.poses).detach().cpu())
predictions = torch.cat(prediction_list, dim=0)