def main(_):
    """Launch training through `train_base.run` with the default setup.

    No custom input example parser is supplied (`None`), and the keypoint
    distance config override, model-input-function kwargs and embedder
    kwargs are all passed as empty dicts.

    Args:
        _: Unused positional argument (presumably the leftover argv handed
            over by the application runner; confirm against the caller).
    """
    # Keyword arguments are gathered in the same order the call site used,
    # so every flag/module attribute is read in the original sequence.
    run_kwargs = dict(
        master=FLAGS.master,
        input_dataset_class=tf.data.TFRecordDataset,
        common_module=common,
        keypoint_profiles_module=keypoint_profiles,
        models_module=models,
        input_example_parser_creator=None,
        keypoint_preprocessor_3d=input_generator.preprocess_keypoints_3d,
        keypoint_distance_config_override={},
        create_model_input_fn_kwargs={},
        embedder_fn_kwargs={},
    )
    train_base.run(**run_kwargs)
def main(_):
    """Launch training through `train_base.run` with sequential inputs.

    Compared with a plain `train_base.run` call, this entry point:
      * supplies an example parser factory built from
        `tfse_input_layer.create_tfse_parser`, with `sequence_length` bound
        to `FLAGS.input_sequence_length`;
      * passes `{'sequential_inputs': True}` to the model-input function;
      * takes the keypoint distance config override from
        `_create_keypoint_distance_config_override()`;
      * forwards the late-fusion preprojection flags to the embedder.

    Args:
        _: Unused positional argument (presumably the leftover argv handed
            over by the application runner; confirm against the caller).
    """
    # Parser factory with the sequence length fixed from the command line.
    parser_creator = functools.partial(
        tfse_input_layer.create_tfse_parser,
        sequence_length=FLAGS.input_sequence_length,
    )
    model_input_fn_kwargs = {'sequential_inputs': True}

    # Built in the original keyword order so flag reads and the override
    # helper call happen in the same sequence as before.
    run_kwargs = dict(
        master=FLAGS.master,
        input_dataset_class=tf.data.TFRecordDataset,
        common_module=common,
        keypoint_profiles_module=keypoint_profiles,
        models_module=models,
        input_example_parser_creator=parser_creator,
        keypoint_preprocessor_3d=input_generator.preprocess_keypoints_3d,
        keypoint_distance_config_override=(
            _create_keypoint_distance_config_override()
        ),
        create_model_input_fn_kwargs=model_input_fn_kwargs,
        embedder_fn_kwargs={
            'num_late_fusion_preprojection_nodes':
                FLAGS.num_late_fusion_preprojection_nodes,
            'late_fusion_preprojection_activation_fn':
                FLAGS.late_fusion_preprojection_activation_fn,
        },
    )
    train_base.run(**run_kwargs)