# Assumed imports and test-class wrapper, following the google-research/poem
# repository layout (the original snippet is a method of a tf.test.TestCase):
import os

from absl import flags
import tensorflow as tf

from poem.core import keypoint_profiles
from poem.core import pipeline_utils

FLAGS = flags.FLAGS


class PipelineUtilsTest(tf.test.TestCase):

    def test_read_batch_from_dataset_tables(self):
        testdata_dir = 'poem/testdata'  # Assume $PWD == "google_research/".
        table_path = os.path.join(FLAGS.test_srcdir, testdata_dir,
                                  'tfe-2.tfrecords')
        inputs = pipeline_utils.read_batch_from_dataset_tables(
            [table_path, table_path],
            batch_sizes=[4, 2],
            num_instances_per_record=2,
            shuffle=True,
            num_epochs=None,
            keypoint_names_3d=keypoint_profiles.create_keypoint_profile_or_die(
                'LEGACY_3DH36M17').keypoint_names,
            keypoint_names_2d=keypoint_profiles.create_keypoint_profile_or_die(
                'LEGACY_2DCOCO13').keypoint_names,
            seed=0)
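        # The two tables contribute batches of 4 and 2 records, concatenated
        # along the leading dimension (4 + 2 = 6); the second dimension is
        # `num_instances_per_record`.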

        self.assertCountEqual(inputs.keys(), [
            'image_sizes', 'keypoints_2d', 'keypoint_scores_2d',
            'keypoint_masks_2d', 'keypoints_3d'
        ])
        self.assertEqual(inputs['image_sizes'].shape, [6, 2, 2])
        self.assertEqual(inputs['keypoints_2d'].shape, [6, 2, 13, 2])
        self.assertEqual(inputs['keypoint_scores_2d'].shape, [6, 2, 13])
        self.assertEqual(inputs['keypoint_masks_2d'].shape, [6, 2, 13])
        self.assertEqual(inputs['keypoints_3d'].shape, [6, 2, 17, 3])
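The assertions above encode the batch-mixing contract: each input table gets its own batch size, and the per-table batches are concatenated along the leading dimension. A minimal sketch of that behavior using plain tf.data (stand-in tensors and sizes for illustration, not the poem implementation):

import tensorflow as tf

# Two stand-in "tables" with per-table batch sizes 4 and 2.
ds_a = tf.data.Dataset.from_tensor_slices(tf.zeros([100, 2, 13, 2])).batch(4)
ds_b = tf.data.Dataset.from_tensor_slices(tf.ones([100, 2, 13, 2])).batch(2)
# Zip and concatenate so every mixed batch has leading dimension 4 + 2 = 6.
mixed = tf.data.Dataset.zip((ds_a, ds_b)).map(
    lambda a, b: tf.concat([a, b], axis=0))
for batch in mixed.take(1):
    assert batch.shape == (6, 2, 13, 2)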
Example 2
# Note: `configs`, `common_module`, `input_dataset_class`,
# `input_example_parser_creator` and `keypoint_preprocessor_3d` are defined
# by the enclosing training script; this function closes over them.
def create_inputs():
    """Creates pipeline and model inputs."""
    inputs = pipeline_utils.read_batch_from_dataset_tables(
        FLAGS.input_table,
        batch_sizes=[int(x) for x in FLAGS.batch_size],
        num_instances_per_record=2,
        shuffle=True,
        num_epochs=None,
        keypoint_names_3d=configs['keypoint_profile_3d'].keypoint_names,
        keypoint_names_2d=configs['keypoint_profile_2d'].keypoint_names,
        min_keypoint_score_2d=FLAGS.min_input_keypoint_score_2d,
        shuffle_buffer_size=FLAGS.input_shuffle_buffer_size,
        common_module=common_module,
        dataset_class=input_dataset_class,
        input_example_parser_creator=input_example_parser_creator)

    # Normalize the 3-D keypoints and fold the preprocessor's side outputs
    # (auxiliary tensors produced during normalization) back into `inputs`.
    (inputs[common_module.KEY_KEYPOINTS_3D],
     keypoint_preprocessor_side_outputs_3d) = keypoint_preprocessor_3d(
         inputs[common_module.KEY_KEYPOINTS_3D],
         keypoint_profile_3d=configs['keypoint_profile_3d'],
         normalize_keypoints_3d=True)
    inputs.update(keypoint_preprocessor_side_outputs_3d)

    # Build the model inputs from the 2-D keypoints and the preprocessed 3-D
    # keypoints, with random camera projection parameters (azimuth, elevation,
    # roll and camera depth ranges) taken from the config.
    inputs['model_inputs'], side_inputs = configs['create_model_input_fn'](
        inputs[common_module.KEY_KEYPOINTS_2D],
        inputs[common_module.KEY_KEYPOINT_MASKS_2D],
        inputs[common_module.KEY_PREPROCESSED_KEYPOINTS_3D],
        model_input_keypoint_type=FLAGS.model_input_keypoint_type,
        normalize_keypoints_2d=True,
        keypoint_profile_2d=configs['keypoint_profile_2d'],
        keypoint_profile_3d=configs['keypoint_profile_3d'],
        azimuth_range=configs['random_projection_azimuth_range'],
        elevation_range=configs['random_projection_elevation_range'],
        roll_range=configs['random_projection_roll_range'],
        normalized_camera_depth_range=(
            configs['random_projection_camera_depth_range']))
    # Merge the side inputs into the main input dict and return everything.
    data_utils.merge_dict(side_inputs, inputs)
    return inputs
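The final merge step implies dict-combining semantics along these lines; a hypothetical stand-in for data_utils.merge_dict, assuming it copies entries from the source dict into the target and rejects key collisions (the actual poem implementation may differ):

def merge_dict(source_dict, target_dict):
    """Copies `source_dict` entries into `target_dict`, without overwrites."""
    for key, value in source_dict.items():
        if key in target_dict:
            raise ValueError('Duplicate key: %s' % key)
        target_dict[key] = value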