예제 #1
0
def main(_):
    """Start training with the baseline (non-sequential) configuration.

    All customization hooks accepted by ``train_base.run`` are left neutral:
    no custom input-example parser is supplied, and the keypoint distance
    override, model-input kwargs and embedder kwargs are all empty dicts.
    Examples are read with ``tf.data.TFRecordDataset``.

    Args:
      _: Ignored positional argument. Presumably the leftover argv handed over
        by an ``app.run``-style launcher; TODO confirm against the caller.
    """
    # Collect the options first (same order as the keyword list they replace),
    # then hand them to the shared training entry point in one call.
    run_kwargs = {
        'master': FLAGS.master,
        'input_dataset_class': tf.data.TFRecordDataset,
        'common_module': common,
        'keypoint_profiles_module': keypoint_profiles,
        'models_module': models,
        # None -> let train_base pick its own parser.
        'input_example_parser_creator': None,
        'keypoint_preprocessor_3d': input_generator.preprocess_keypoints_3d,
        'keypoint_distance_config_override': {},
        'create_model_input_fn_kwargs': {},
        'embedder_fn_kwargs': {},
    }
    train_base.run(**run_kwargs)
예제 #2
0
def main(_):
    """Start training on fixed-length keypoint sequences.

    Compared with the baseline launcher, this variant:
      * parses examples with ``tfse_input_layer.create_tfse_parser``, bound to
        ``FLAGS.input_sequence_length``;
      * asks for sequential model inputs (``{'sequential_inputs': True}``);
      * takes the keypoint distance override from
        ``_create_keypoint_distance_config_override()``;
      * forwards the late-fusion pre-projection node count and activation
        function flags to the embedder.

    Args:
      _: Ignored positional argument. Presumably the leftover argv handed over
        by an ``app.run``-style launcher; TODO confirm against the caller.
    """
    parser_factory = functools.partial(
        tfse_input_layer.create_tfse_parser,
        sequence_length=FLAGS.input_sequence_length)

    model_input_options = {'sequential_inputs': True}

    # Dict literals evaluate left to right, so the flag reads and the
    # override-helper call below happen in the same order as before.
    run_kwargs = {
        'master': FLAGS.master,
        'input_dataset_class': tf.data.TFRecordDataset,
        'common_module': common,
        'keypoint_profiles_module': keypoint_profiles,
        'models_module': models,
        'input_example_parser_creator': parser_factory,
        'keypoint_preprocessor_3d': input_generator.preprocess_keypoints_3d,
        'keypoint_distance_config_override':
            _create_keypoint_distance_config_override(),
        'create_model_input_fn_kwargs': model_input_options,
        'embedder_fn_kwargs': {
            'num_late_fusion_preprojection_nodes':
                FLAGS.num_late_fusion_preprojection_nodes,
            'late_fusion_preprojection_activation_fn':
                FLAGS.late_fusion_preprojection_activation_fn,
        },
    }
    train_base.run(**run_kwargs)