Example #1
0
    def test_merge_dict(self):
        """Tests `data_utils.merge_dict(source_dict, target_dict)`.

        Checks that entries of the source dict are added to the target dict
        in place, and that a key present in both dicts raises a ValueError.
        """
        # Disjoint keys: the source entry `c` ends up in the target dict.
        target_dict = {'a': 1, 'b': 2}
        source_dict = {'c': 4}
        data_utils.merge_dict(source_dict, target_dict)
        self.assertDictEqual(target_dict, {'a': 1, 'b': 2, 'c': 4})

        # Overlapping key `b`: the merge is expected to raise.
        target_dict = {'a': 1, 'b': 2}
        source_dict = {'b': 3, 'c': 4}
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported name. The second
        # argument is a regex searched in the message, so the trailing `.`
        # matches any single character.
        with self.assertRaisesRegex(ValueError, 'Key conflict: `b`.'):
            data_utils.merge_dict(source_dict, target_dict)
Example #2
0
            def create_inputs():
                """Reads a batch of examples and builds the model inputs.

                Loads a shuffled, infinitely repeating batch from
                `FLAGS.input_table`. It then runs `keypoint_preprocessor_3d`
                on the 3D keypoints with normalization enabled, calls
                `configs['create_model_input_fn']` to build the model inputs,
                and folds every result into the batch dictionary.

                Returns:
                  The dictionary produced by
                  `pipeline_utils.read_batch_from_dataset_tables`, modified in
                  place to hold the preprocessed 3D keypoints, the
                  preprocessor side outputs, `'model_inputs'`, and the
                  side inputs from the model-input function.
                """
                profile_3d = configs['keypoint_profile_3d']
                profile_2d = configs['keypoint_profile_2d']

                batch = pipeline_utils.read_batch_from_dataset_tables(
                    FLAGS.input_table,
                    batch_sizes=[int(size) for size in FLAGS.batch_size],
                    num_instances_per_record=2,
                    shuffle=True,
                    num_epochs=None,
                    keypoint_names_3d=profile_3d.keypoint_names,
                    keypoint_names_2d=profile_2d.keypoint_names,
                    min_keypoint_score_2d=FLAGS.min_input_keypoint_score_2d,
                    shuffle_buffer_size=FLAGS.input_shuffle_buffer_size,
                    common_module=common_module,
                    dataset_class=input_dataset_class,
                    input_example_parser_creator=input_example_parser_creator)

                # Replace the raw 3D keypoints with the preprocessed ones and
                # add the preprocessor's side outputs. Note: `dict.update`
                # silently overwrites any keys already present in the batch.
                keypoints_3d, preprocessor_side_outputs = (
                    keypoint_preprocessor_3d(
                        batch[common_module.KEY_KEYPOINTS_3D],
                        keypoint_profile_3d=profile_3d,
                        normalize_keypoints_3d=True))
                batch[common_module.KEY_KEYPOINTS_3D] = keypoints_3d
                batch.update(preprocessor_side_outputs)

                # NOTE(review): KEY_PREPROCESSED_KEYPOINTS_3D is presumably
                # supplied by the preprocessor side outputs above — confirm.
                create_model_input_fn = configs['create_model_input_fn']
                model_inputs, side_inputs = create_model_input_fn(
                    batch[common_module.KEY_KEYPOINTS_2D],
                    batch[common_module.KEY_KEYPOINT_MASKS_2D],
                    batch[common_module.KEY_PREPROCESSED_KEYPOINTS_3D],
                    model_input_keypoint_type=FLAGS.model_input_keypoint_type,
                    normalize_keypoints_2d=True,
                    keypoint_profile_2d=profile_2d,
                    keypoint_profile_3d=profile_3d,
                    azimuth_range=configs['random_projection_azimuth_range'],
                    elevation_range=(
                        configs['random_projection_elevation_range']),
                    roll_range=configs['random_projection_roll_range'],
                    normalized_camera_depth_range=(
                        configs['random_projection_camera_depth_range']))
                batch['model_inputs'] = model_inputs

                # Unlike `update` above, `merge_dict` is presumably strict
                # about key conflicts (its unit test expects a ValueError).
                data_utils.merge_dict(side_inputs, batch)
                return batch