Пример #1
0
    def test_collect_errors_and_input_loads_from_database(self):
        """Results missing from the cache are fetched from the database.

        Half of the requested ids are pre-populated in ``results_cache``. The
        test checks that the results collection is queried with an ``$in``
        filter listing exactly the uncached ids, and that those ids end up in
        ``results_cache`` afterwards.
        """
        def mock_deserialize(s_result):
            # Stand-in for a deserialized trial result: keeps the queried id
            # and supplies random 100x25 observations.
            mock_result = mock.Mock()
            mock_result.identifier = s_result['_id']
            mock_result.observations = np.random.normal(0, 1, size=(100, 25))
            return mock_result

        mock_db_client = mock_db_client_fac.create().mock
        mock_db_client.results_collection = mock.Mock()
        # Echo back one document per id requested through the '$in' filter.
        mock_db_client.results_collection.find.side_effect = \
            lambda query: [{'_id': oid} for oid in query['_id']['$in']]
        mock_db_client.deserialize_entity.side_effect = mock_deserialize

        cached_ids = {bson.ObjectId() for _ in range(5)}
        unloaded_ids = {bson.ObjectId() for _ in range(5)}
        results_cache = {
            result_id: np.array([[2500 * k + 25 * j + i for i in range(25)]
                                 for j in range(100)])
            for k, result_id in enumerate(cached_ids)
        }
        gprwe.collect_errors_and_input(cached_ids | unloaded_ids,
                                       mock_db_client, results_cache)

        self.assertTrue(mock_db_client.results_collection.find.called)
        query = mock_db_client.results_collection.find.call_args[0][0]
        self.assertEqual({'_id'}, set(query))
        self.assertEqual({'$in'}, set(query['_id']))
        requested_ids = query['_id']['$in']
        # The implementation builds this list by iterating a *different* set
        # than ``unloaded_ids``, and set iteration order is not guaranteed to
        # agree between sets, so compare contents rather than list order. The
        # length check guards against duplicate ids being requested.
        self.assertEqual(len(unloaded_ids), len(requested_ids))
        self.assertEqual(unloaded_ids, set(requested_ids))
        for unloaded_id in unloaded_ids:
            self.assertIn(unloaded_id, results_cache)
Пример #2
0
    def test_predict_real_and_virtual_errors(self):
        """Predict errors for real-world and per-quality virtual result sets.

        Every result id is pre-loaded into ``results_cache`` so the db client
        mock does not need configuring. Checks the runtime budget and the
        number and length of the returned score lists for each group.
        """
        validation_results = [bson.ObjectId() for _ in range(10)]
        real_world_results = [bson.ObjectId() for _ in range(10)]
        # Each quality level maps to a pair of sets of 10 result ids.
        virtual_results_by_quality = {
            'quality_{0}'.format(idx): (
                {bson.ObjectId() for _ in range(10)},
                {bson.ObjectId() for _ in range(10)})
            for idx in range(4)
        }

        # Build a results cache so we don't have to do anything for the db client
        results_ids = set(validation_results) | set(real_world_results)
        for result_set_1, result_set_2 in virtual_results_by_quality.values():
            # Both members are already sets; the union builds a new set, so
            # the originals are not mutated.
            results_ids |= result_set_1 | result_set_2
        predictors = create_test_predictors()
        errors = create_all_errors(predictors)
        # Split the rows evenly so each result id gets its own slice.
        obs_size = len(predictors) // len(results_ids)
        results_cache = {
            result_id: np.hstack(
                (errors[obs_size * ridx:obs_size * (ridx + 1), :],
                 predictors[obs_size * ridx:obs_size * (ridx + 1), :]))
            for ridx, result_id in enumerate(results_ids)
        }

        # predict the errors, timing with a monotonic high-resolution clock
        # (time.time() follows the wall clock and can jump if it is adjusted)
        start = time.perf_counter()
        real_world_scores, errors_by_group = gprwe.predict_real_and_virtual_errors(
            validation_results=validation_results,
            real_world_results=real_world_results,
            virtual_results_by_quality=virtual_results_by_quality,
            db_client=mock_db_client_fac.create().mock,
            results_cache=results_cache)
        elapsed = time.perf_counter() - start

        print("predict real and virtual errors time: {0}".format(elapsed))
        self.assertLess(elapsed, 200)
        # NOTE(review): 13 appears to match the number of error columns
        # (y.shape[1] in the collect_errors_and_input tests) — confirm.
        self.assertEqual(13, len(real_world_scores))
        expected_length = obs_size * len(validation_results)
        for result_list in real_world_scores:
            self.assertEqual(expected_length, len(result_list))
        for quality in virtual_results_by_quality:
            for group in (' all data', ' no validation trajectory',
                          ' only validation trajectory'):
                group_name = quality + group
                self.assertIn(group_name, errors_by_group)
                self.assertEqual(13, len(errors_by_group[group_name]))
                for result_list in errors_by_group[group_name]:
                    self.assertEqual(
                        expected_length, len(result_list),
                        "wrong number of results for group {0}".format(
                            group_name))
Пример #3
0
 def test_collect_errors_and_input_loads_copies_from_cache(self):
     """Writing into the returned arrays must not alter the cached array."""
     cached_id = bson.ObjectId()
     db_client = mock_db_client_fac.create().mock
     # 100 rows x 25 columns holding 0..2499 in row-major order.
     cached_data = np.array([[25 * row + col for col in range(25)]
                             for row in range(100)])
     results_cache = {cached_id: cached_data}

     x, y = gprwe.collect_errors_and_input({cached_id}, db_client,
                                           results_cache)

     # Write negative sentinels: every cached value is non-negative, so any
     # write that reached the cache through a shared view shows up below.
     x[:, 0] = -10
     y[:, 0] = -20

     stored = results_cache[cached_id]
     self.assertEqual(0, stored[0, 0])
     self.assertEqual(13, stored[0, 13])
     self.assertTrue(np.all(stored >= 0))
Пример #4
0
 def test_collect_errors_and_input_loads_from_cache(self):
     """Fully-cached results are stacked into (x, y) with the expected shapes."""
     ids = [bson.ObjectId() for _ in range(5)]
     db_client = mock_db_client_fac.create().mock
     results_cache = {}
     for k, rid in enumerate(ids):
         # Distinct 100x25 block per result: values 2500*k .. 2500*k + 2499.
         results_cache[rid] = np.array(
             [[2500 * k + 25 * j + i for i in range(25)] for j in range(100)])

     x, y = gprwe.collect_errors_and_input(ids, db_client, results_cache)

     # 5 results x 100 rows = 500 rows; the asserted widths (12 + 13)
     # account for all 25 cached columns.
     for array, width in ((x, 12), (y, 13)):
         self.assertEqual(500, array.shape[0])
         self.assertEqual(width, array.shape[1])
Пример #5
0
    def test_analyse_distributions(self):
        """Smoke test: ``analyse_distributions`` runs on mocked trial results.

        Each trial result served by the mocked database has 1000x25 random
        observations with column 6 set entirely to NaN (presumably to exercise
        missing-value handling). The test only checks that no exception is
        raised. Anything written to ``output_folder`` is removed afterwards,
        even if the test fails.
        """
        import shutil

        def mock_deserialize(s_result):
            mock_result = mock.Mock()
            mock_result.identifier = s_result['_id']
            mock_result.observations = np.random.normal(0, 1, size=(1000, 25))
            mock_result.observations[:, 6] = np.nan
            return mock_result

        def make_trajectory_group(name):
            # Both trajectory groups have identical structure; only the name
            # and the freshly generated ids differ.
            return tg.TrajectoryGroup(
                name=name,
                reference_id=bson.ObjectId(),
                mappings=[('Block World', {
                    'location': [12, -63.2, 291.1],
                    'rotation': [-22, -214, 121]
                })],
                baseline_configuration={'test': bson.ObjectId()},
                controller_id=bson.ObjectId(),
                generated_datasets={
                    simulator_name: {
                        'max quality': bson.ObjectId(),
                        'min quality': bson.ObjectId()
                    }
                    for simulator_name in ('Block World', 'Block World 2')
                })

        mock_db_client = mock_db_client_fac.create().mock
        mock_db_client.results_collection = mock.Mock()
        # Echo back one document per id requested through the '$in' filter.
        mock_db_client.results_collection.find.side_effect = \
            lambda query: [{'_id': oid} for oid in query['_id']['$in']]
        mock_db_client.deserialize_entity.side_effect = mock_deserialize
        system_name = 'test_system'
        system_id = bson.ObjectId()
        output_folder = 'temp-test-analyse-distributions'
        # Don't leave analysis output behind in the working directory;
        # ignore_errors covers the case where nothing was written.
        self.addCleanup(shutil.rmtree, output_folder, ignore_errors=True)

        subject = gprwe.GeneratedPredictRealWorldExperiment(
            systems={system_name: system_id},
            simulators={
                'Block World': bson.ObjectId(),
                'Block World 2': bson.ObjectId()
            },
            trajectory_groups={
                group_name: make_trajectory_group(group_name)
                for group_name in ('KITTI trajectory 1', 'KITTI trajectory 2')
            },
            benchmarks={'Estimate Errors': bson.ObjectId()},
            enabled=True)

        # Register three trial results per (system, dataset) and one benchmark
        # result per (system, dataset, benchmark).
        for sys_id in subject.systems.values():
            for traj_group in subject.trajectory_groups.values():
                for dataset_id in traj_group.get_all_dataset_ids():
                    subject.store_trial_results(
                        system_id=sys_id,
                        image_source_id=dataset_id,
                        trial_result_ids=[bson.ObjectId() for _ in range(3)],
                        db_client=mock_db_client)
                    for benchmark_id in subject.benchmarks.values():
                        subject.store_benchmark_result(
                            system_id=sys_id,
                            image_source_id=dataset_id,
                            benchmark_id=benchmark_id,
                            benchmark_result_id=bson.ObjectId())

        results_cache = {}
        # NOTE(review): a previous version patched
        # 'generated_predict_rw_experiment.create_distribution_plots' around
        # this call; it is not patched now, so any plotting presumably runs.
        subject.analyse_distributions(system_name=system_name,
                                      output_folder=output_folder,
                                      db_client=mock_db_client,
                                      results_cache=results_cache)