Example #1
    def create_logger(params: Parameters) -> "LearningProgressHtmlLogger":
        output_dir = params.creatable_directory("experiment_group_dir")
        experiment_name = params.string("experiment")
        include_links_to_images = params.optional_boolean("include_image_links")
        num_pretty_descriptions = params.positive_integer(
            "num_pretty_descriptions", default=3
        )
        sort_by_length = params.boolean(
            "sort_learner_descriptions_by_length", default=False
        )

        logging_dir = output_dir / experiment_name
        logging_dir.mkdir(parents=True, exist_ok=True)
        output_html_path = str(logging_dir / "index.html")

        if include_links_to_images is None:
            include_links_to_images = False

        logging.info("Experiment will be logged to %s", output_html_path)

        with open(output_html_path, "w") as outfile:
            html_dumper = CurriculumToHtmlDumper()

            outfile.write(f"<head>\n\t<style>{CSS}\n\t</style>\n</head>")
            outfile.write(f"\n<body>\n\t<h1>{experiment_name}</h1>")
            # A JavaScript function to allow toggling perception information
            outfile.write(
                """
                <script>
                function myFunction(id) {
                  var x = document.getElementById(id);
                  if (x.style.display === "none") {
                    x.style.display = "block";
                  } else {
                    x.style.display = "none";
                  }
                }
                </script>
                """
            )
        return LearningProgressHtmlLogger(
            outfile_dir=output_html_path,
            html_dumper=html_dumper,
            include_links_to_images=include_links_to_images,
            num_pretty_descriptions=num_pretty_descriptions,
            sort_by_length=sort_by_length,
        )
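
A quick usage sketch: assuming, as the indentation suggests, that create_logger is a static factory method on LearningProgressHtmlLogger, and that the vistautils-style Parameters object can be built from a mapping (Parameters.from_mapping; all values below are hypothetical):

# hypothetical invocation; the keys mirror the parameters read above
params = Parameters.from_mapping({
    "experiment_group_dir": "/tmp/experiments",
    "experiment": "demo_run",
    "num_pretty_descriptions": 2,
})
logger = LearningProgressHtmlLogger.create_logger(params)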
Example #2
def main(params: Parameters):
    with byte_key_value_source_from_params(params) as input_source:
        keys = list(input_source.keys())
        num_to_sample = min(params.positive_integer(_NUM_TO_SAMPLE_PARAM),
                            len(keys))
        rng = random.Random(params.integer(_RANDOM_SEED_PARAM, default=0))
        # random.shuffle's `random` argument was removed in Python 3.11;
        # shuffling with a seeded Random instance is the portable equivalent
        rng.shuffle(keys)
        keys_to_keep = keys[:num_to_sample]
        output_zip_path = params.creatable_file("output_zip_path")
        logging.info("Downsampling %s files to %s", num_to_sample,
                     output_zip_path)
        with KeyValueSink.zip_bytes_sink(output_zip_path) as out:
            for key in keys_to_keep:
                out.put(key, input_source[key])
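
The core idiom here, a seeded shuffle followed by a slice, yields a reproducible random sample. A self-contained stdlib sketch of the same pattern on a plain list:

import random

def downsample(items, num_to_sample, seed=0):
    """Deterministically sample up to num_to_sample items."""
    keys = list(items)
    random.Random(seed).shuffle(keys)  # seeded, so repeatable across runs
    return keys[:min(num_to_sample, len(keys))]

# the same seed always yields the same sample
assert downsample(range(100), 5, seed=0) == downsample(range(100), 5, seed=0)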
Example #3
def _split_into_even_slices(input_source: KeyValueSource[str, bytes],
                            params: Parameters):
    output_directory = params.creatable_directory("output_dir")
    slices = params.positive_integer("num_slices")
    random_seed = params.optional_positive_integer("random_seed")
    slice_paths = [output_directory / f"{i}.zip" for i in range(slices)]
    CharSink.to_file(output_directory / "_slices.txt").write(
        "\n".join(str(x) for x in slice_paths)
    )
    output_sinks = [
        KeyValueSink.zip_bytes_sink(slice_path) for slice_path in slice_paths
    ]
    # this is the magic incantation for handling variable-length lists of context managers
    with ExitStack() as exit_stack:
        for output_sink in output_sinks:
            exit_stack.enter_context(output_sink)
        # guarantee a deterministic iteration order
        input_keys = sorted(input_source.keys())  # type: ignore
        if random_seed:
            random.seed(random_seed)
            random.shuffle(input_keys)
        for (i, k) in enumerate(input_keys):
            output_sinks[i % slices].put(k, input_source[k])
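
The comment above calls ExitStack "the magic incantation" for a reason: it is the standard way to enter a runtime-sized list of context managers under a single with block. A self-contained stdlib sketch of the same round-robin write pattern:

from contextlib import ExitStack
from pathlib import Path
import tempfile

out_dir = Path(tempfile.mkdtemp())
paths = [out_dir / f"{i}.txt" for i in range(3)]

with ExitStack() as stack:
    # every file entered here is closed when the block exits,
    # even if an exception is raised part-way through
    sinks = [stack.enter_context(open(p, "w")) for p in paths]
    for i, record in enumerate(["a", "b", "c", "d", "e"]):
        sinks[i % len(sinks)].write(record + "\n")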
Example #4
def main(
    params: Parameters,
    scenes_iterable_input: Optional[Iterable[Phase1InstanceGroup]] = None,
    output_directory: Optional[Path] = None,
    visualizer: Optional[SituationVisualizer] = None,
) -> None:

    language_mode = params.enum("language_mode",
                                LanguageMode,
                                default=LanguageMode.ENGLISH)

    if scenes_iterable_input is None:
        scenes_iterable: Iterable[Phase1InstanceGroup] = [
            make_curriculum(None, None,
                            phase2_language_generator(language_mode))
        ]
    else:
        scenes_iterable = scenes_iterable_input

    num_iterations = params.positive_integer("iterations")
    steps_before_vis = params.positive_integer("steps_before_vis")

    specific_scene = params.optional_positive_integer("scene")

    automatically_save_renderings = params.boolean(
        "automatically_save_renderings", default=False)

    if "experiment_group_dir" in params:
        rendering_filename_generator = from_experiment_filename_generator
    else:
        rendering_filename_generator = default_filename_generator

    screenshot_dir = output_directory

    random.seed(params.integer("seed"))
    np.random.seed(params.integer("seed"))

    if params.string("debug_bounding_boxes", default="off") == "on":
        debug_bounding_boxes = True
    else:
        debug_bounding_boxes = False

    if params.string("gaze_arrows", default="off") == "on":
        gaze_arrows = True
    else:
        gaze_arrows = False

    # go through curriculum scenes and output geometry types
    if visualizer is None:
        viz = SituationVisualizer()
    else:
        viz = visualizer
        viz.clear_scene()
    model_scales = viz.get_model_scales()
    for object_type, multiplier in OBJECT_SCALE_MULTIPLIER_MAP.items():
        if object_type in model_scales:
            v3 = model_scales[object_type]
            new_v3 = (v3[0] * multiplier, v3[1] * multiplier,
                      v3[2] * multiplier)
            model_scales[object_type] = new_v3
        else:
            model_scales[object_type] = (multiplier, multiplier, multiplier)

    for model_name, scale in model_scales.items():
        logging.info("SCALE: %s -> %s", model_name, scale)

    # used to start a frame from where the previous one left off
    previous_model_positions: Optional[PositionsMap] = None

    for scene_number, scene_elements in enumerate(
            SceneCreator.create_scenes(scenes_iterable)):
        # If a scene number is provided in the params file, only render that scene
        if specific_scene and scene_number < specific_scene:
            continue
        if specific_scene and scene_number > specific_scene:
            break

        scene_filename = rendering_filename_generator(scene_number,
                                                      scene_elements)
        if scene_filename in _FILENAMES_USED:
            continue
        _FILENAMES_USED.add(scene_filename)

        print(f"SCENE {scene_number}")
        viz.set_title(
            f"{' '.join(scene_elements.tokens)} "
            f"({scene_elements.current_frame + 1}/{scene_elements.total_frames})"
        )

        # if this is a new scene, forget the positions from the last scene
        if scene_elements.current_frame == 0:
            previous_model_positions = None

        if automatically_save_renderings:
            # if in auto mode and the scene contains an excluded vocab word, skip it
            if any(token in EXCLUDED_VOCAB for token in scene_elements.tokens):
                continue

        # for debugging purposes:
        # SceneCreator.graph_for_each(scene_elements.object_graph, print_obj_names)

        # bind visualizer and properties to top level rendering function:
        bound_render_obj = partial(render_obj, viz,
                                   scene_elements.property_map,
                                   previous_model_positions)
        # bind visualizer and properties to nested obj rendering function
        bound_render_nested_obj = partial(render_obj_nested, viz,
                                          scene_elements.property_map,
                                          previous_model_positions)

        # render each object in graph
        SceneCreator.graph_for_each_top_level(scene_elements.object_graph,
                                              bound_render_obj,
                                              bound_render_nested_obj)

        # apply scale to top level nodes in scene
        for node in scene_elements.object_graph:
            if (node.name not in OBJECT_NAMES_TO_EXCLUDE and
                    node.name.split("_")[0] in OBJECT_SCALE_MULTIPLIER_MAP):
                viz.multiply_scale(
                    node.name,
                    OBJECT_SCALE_MULTIPLIER_MAP[node.name.split("_")[0]])

        # find the Region relations that refer to separate objects:
        # (e.g. the cookie is in the region of the hand (of the person), not the leg-segment in in the region of the torso).
        inter_object_in_region_map: DefaultDict[
            ObjectPerception, List[Region[ObjectPerception]]
        ] = defaultdict(list)
        for top_level_node in scene_elements.object_graph:
            if top_level_node.perceived_obj in scene_elements.in_region_map:
                inter_object_in_region_map[top_level_node.perceived_obj] = (
                    scene_elements.in_region_map[top_level_node.perceived_obj]
                )

        # print(inter_object_in_region_map)

        # we want to assemble a lookup of the offsets (position) of each object's subobjects.
        sub_object_offsets = {}

        for node_name, node in viz.geo_nodes.items():
            child_node_to_offset = {}

            recurse_list: List[NodePath] = node.children
            while recurse_list:
                next_batch: List[NodePath] = []
                for child in recurse_list:
                    next_batch += child.children
                    # make sure this is a sub-object
                    if child.hasMat() and child.parent.name != node_name:
                        # child has a non-identity transformation matrix applied to it
                        # (its transform differs from its parent's)
                        # TODO: we could re-export all of the models in such a way as to
                        #       eliminate this extra layer in the scene graph
                        child_node_to_offset[child.parent.name] = child.get_pos()
                recurse_list = next_batch

            sub_object_offsets[node_name] = child_node_to_offset

        # handle skipping scene
        if not automatically_save_renderings:
            viz.run_for_seconds(1)
            skip_command = input("type 's' and hit ENTER to skip this scene")
            if skip_command == "s":
                viz.clear_scene()
                viz.run_for_seconds(0.25)
                continue

        handle_to_in_region_map = {
            object_perception.debug_handle: region_list
            for object_perception, region_list in
            inter_object_in_region_map.items()
        }

        frozen_objects = objects_to_freeze(
            handle_to_in_region_map,
            scene_elements.situation,
            scene_elements.situation_object_to_handle,
        )

        if scene_elements.interpolated_scene_moving_items:
            # freeze everything not included in the interpolated scene
            frozen_objects = (
                immutableset(
                    key.debug_handle
                    for key in scene_elements.in_region_map.keys()
                )
                - scene_elements.interpolated_scene_moving_items
            )

        # now that every object has been instantiated into the scene,
        # they need to be re-positioned.

        repositioned_map = None

        for repositioned_map in _solve_top_level_positions(
                top_level_objects=immutableset([
                    node.perceived_obj for node in scene_elements.object_graph
                    if node.name not in OBJECT_NAMES_TO_EXCLUDE
                ]),
                sub_object_offsets=sub_object_offsets,
                in_region_map=inter_object_in_region_map,
                model_scales=model_scales,
                frozen_objects=frozen_objects,
                iterations=num_iterations,
                yield_steps=steps_before_vis,
                previous_positions=previous_model_positions,
        ):
            viz.clear_debug_nodes()
            viz.clear_gaze_arrows()
            if not automatically_save_renderings:
                viz.run_for_seconds(0.25)

            viz.set_positions(repositioned_map)

            if debug_bounding_boxes:
                for name in repositioned_map.name_to_position:
                    viz.add_debug_bounding_box(
                        name,
                        repositioned_map.name_to_position[name],
                        repositioned_map.name_to_scale[name],
                    )

            if gaze_arrows:
                for handle, props in scene_elements.property_map.items():
                    for prop in props:
                        if isinstance(prop, OntologyNode) and prop.handle == "gazed-at":
                            viz.add_gaze_arrow(
                                handle,
                                repositioned_map.name_to_position[handle],
                                repositioned_map.name_to_scale[handle],
                            )
            # the visualizer seems to need about a second to render an update
            if not automatically_save_renderings:
                viz.run_for_seconds(1)
                # viz.print_scene_graph()
            previous_model_positions = None

        # only store previous positions when continuing to next frame / scene
        previous_model_positions = repositioned_map
        viz.run_for_seconds(1)

        screenshot(
            automatically_save_renderings=automatically_save_renderings,
            filename=scene_filename,
            screenshot_dir=screenshot_dir,
            viz=viz,
        )

        viz.clear_scene()
        viz.run_for_seconds(0.25)
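
One pattern worth pulling out of this function is the functools.partial binding used for bound_render_obj and bound_render_nested_obj: the shared rendering context is fixed once, so the graph traversal only needs single-argument callbacks. A minimal sketch (all names hypothetical):

from functools import partial

def render(viz, property_map, previous_positions, node):
    print(f"rendering {node} via {viz} with {property_map}")

# fix the shared context once...
bound_render = partial(render, "my_viz", {"ball": ["red"]}, None)

# ...then hand the traversal a uniform one-argument callback
for node in ["ball", "table"]:
    bound_render(node)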
Example #5
def learner_factory_from_params(
    params: Parameters,
    graph_logger: Optional[HypothesisLogger],
    language_mode: LanguageMode = LanguageMode.ENGLISH,
) -> Callable[[], TopLevelLanguageLearner]:  # type: ignore
    learner_type = params.string(
        "learner",
        [
            "pursuit",
            "object-subset",
            "preposition-subset",
            "attribute-subset",
            "verb-subset",
            "integrated-learner",
            "integrated-learner-recognizer-without-generics",
            "integrated-learner-recognizer",
            "pursuit-gaze",
            "integrated-object-only",
            "integrated-learner-params",
            "integrated-pursuit-attribute-only",
        ],
    )

    beam_size = params.positive_integer("beam_size", default=10)
    rng = random.Random()
    rng.seed(0)
    perception_generator = GAILA_PHASE_1_PERCEPTION_GENERATOR

    objects = [YOU_HACK, ME_HACK]
    objects.extend(PHASE_1_CURRICULUM_OBJECTS)

    # Eval hack! This is specific to the Phase 1 ontology
    object_recognizer = ObjectRecognizer.for_ontology_types(
        objects,
        determiners=ENGLISH_DETERMINERS,
        ontology=GAILA_PHASE_1_ONTOLOGY,
        language_mode=language_mode,
        perception_generator=perception_generator,
    )

    if learner_type == "pursuit":
        return lambda: ObjectPursuitLearner.from_parameters(
            params.namespace("pursuit"), graph_logger=graph_logger)
    elif learner_type == "pursuit-gaze":
        return lambda: IntegratedTemplateLearner(
            object_learner=PursuitObjectLearnerNew(
                learning_factor=0.05,
                graph_match_confirmation_threshold=0.7,
                lexicon_entry_threshold=0.7,
                rng=rng,
                smoothing_parameter=0.002,
                ontology=GAILA_PHASE_2_ONTOLOGY,
                language_mode=language_mode,
                rank_gaze_higher=True,
            ),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
        )
    elif learner_type == "object-subset":
        return lambda: SubsetObjectLearner(ontology=GAILA_PHASE_1_ONTOLOGY,
                                           language_mode=LanguageMode.ENGLISH)
    elif learner_type == "attribute-subset":
        return lambda: SubsetAttributeLearner(
            ontology=GAILA_PHASE_1_ONTOLOGY,
            object_recognizer=object_recognizer,
            language_mode=LanguageMode.ENGLISH,
        )
    elif learner_type == "preposition-subset":
        return lambda: SubsetPrepositionLearner(
            # graph_logger=graph_logger,
            object_recognizer=object_recognizer,
            ontology=GAILA_PHASE_1_ONTOLOGY,
            language_mode=LanguageMode.ENGLISH,
        )
    elif learner_type == "verb-subset":
        return lambda: SubsetVerbLearner(
            ontology=GAILA_PHASE_1_ONTOLOGY,
            object_recognizer=object_recognizer,
            language_mode=LanguageMode.ENGLISH,
        )
    elif learner_type == "integrated-learner":
        return lambda: IntegratedTemplateLearner(
            object_learner=SubsetObjectLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            functional_learner=FunctionalLearner(language_mode=language_mode),
        )
    elif learner_type == "integrated-learner-recognizer":
        return lambda: IntegratedTemplateLearner(
            object_learner=ObjectRecognizerAsTemplateLearner(
                object_recognizer=object_recognizer,
                language_mode=language_mode),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            functional_learner=FunctionalLearner(language_mode=language_mode),
            generics_learner=SimpleGenericsLearner(),
        )
    elif learner_type == "ic":
        return lambda: IntegratedTemplateLearner(
            object_learner=ObjectRecognizerAsTemplateLearner(
                object_recognizer=object_recognizer,
                language_mode=language_mode),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            functional_learner=FunctionalLearner(language_mode=language_mode),
        )
    elif learner_type == "integrated-object-only":
        object_learner_type = params.string(
            "object_learner_type",
            valid_options=["subset", "pbv", "pursuit"],
            default="subset",
        )

        if params.has_namespace("learner_params"):
            learner_params = params.namespace("learner_params")
        else:
            learner_params = params.empty(namespace_prefix="learner_params")

        object_learner_factory: Callable[[], TemplateLearner]
        if object_learner_type == "subset":

            def subset_factory() -> SubsetObjectLearnerNew:
                return SubsetObjectLearnerNew(  # type: ignore
                    ontology=GAILA_PHASE_2_ONTOLOGY,
                    beam_size=beam_size,
                    language_mode=language_mode,
                )

            object_learner_factory = subset_factory

        elif object_learner_type == "pbv":

            def pbv_factory() -> ProposeButVerifyObjectLearner:
                return ProposeButVerifyObjectLearner.from_params(  # type: ignore
                    learner_params)

            object_learner_factory = pbv_factory
        elif object_learner_type == "pursuit":

            def pursuit_factory() -> PursuitObjectLearnerNew:
                return PursuitObjectLearnerNew(  # type: ignore
                    learning_factor=learner_params.floating_point("learning_factor"),
                    graph_match_confirmation_threshold=learner_params.floating_point(
                        "graph_match_confirmation_threshold"
                    ),
                    lexicon_entry_threshold=learner_params.floating_point(
                        "lexicon_entry_threshold"
                    ),
                    rng=rng,
                    smoothing_parameter=learner_params.floating_point(
                        "smoothing_parameter"
                    ),
                    ontology=GAILA_PHASE_2_ONTOLOGY,
                    language_mode=language_mode,
                )

            object_learner_factory = pursuit_factory
        else:
            raise RuntimeError(
                f"Invalid object learner type selected: {object_learner_type}"
            )
        return lambda: IntegratedTemplateLearner(
            object_learner=object_learner_factory()
        )
    elif learner_type == "integrated-learner-params":
        object_learner = build_object_learner_factory(  # type:ignore
            params.namespace_or_empty("object_learner"), beam_size,
            language_mode)
        attribute_learner = build_attribute_learner_factory(  # type:ignore
            params.namespace_or_empty("attribute_learner"), beam_size,
            language_mode)
        relation_learner = build_relation_learner_factory(  # type:ignore
            params.namespace_or_empty("relation_learner"), beam_size,
            language_mode)
        action_learner = build_action_learner_factory(  # type:ignore
            params.namespace_or_empty("action_learner"), beam_size,
            language_mode)
        plural_learner = build_plural_learner_factory(  # type:ignore
            params.namespace_or_empty("plural_learner"), beam_size,
            language_mode)
        return lambda: IntegratedTemplateLearner(
            object_learner=object_learner,
            attribute_learner=attribute_learner,
            relation_learner=relation_learner,
            action_learner=action_learner,
            functional_learner=(
                FunctionalLearner(language_mode=language_mode)
                if params.boolean("include_functional_learner", default=True)
                else None
            ),
            generics_learner=(
                SimpleGenericsLearner()
                if params.boolean("include_generics_learner", default=True)
                else None
            ),
            plural_learner=plural_learner,
            suppress_error=params.boolean("suppress_error", default=True),
        )
    elif learner_type == "integrated-pursuit-attribute-only":
        return lambda: IntegratedTemplateLearner(
            object_learner=ObjectRecognizerAsTemplateLearner(
                object_recognizer=object_recognizer,
                language_mode=language_mode),
            attribute_learner=PursuitAttributeLearnerNew(
                learning_factor=0.05,
                graph_match_confirmation_threshold=0.7,
                lexicon_entry_threshold=0.7,
                rng=rng,
                smoothing_parameter=0.002,
                rank_gaze_higher=False,
                ontology=GAILA_PHASE_1_ONTOLOGY,
                language_mode=language_mode,
            ),
        )
    else:
        # unreachable: params.string has already validated learner_type
        # against the option list above
        raise RuntimeError(f"Unhandled learner type: {learner_type}")
Example #6
def learner_factory_from_params(
    params: Parameters,
    graph_logger: Optional[HypothesisLogger],
    language_mode: LanguageMode = LanguageMode.ENGLISH,
) -> Callable[[], TopLevelLanguageLearner]:  # type: ignore
    learner_type = params.string(
        "learner",
        [
            "pursuit",
            "object-subset",
            "preposition-subset",
            "attribute-subset",
            "verb-subset",
            "integrated-learner",
            "integrated-learner-recognizer",
            "pursuit-gaze",
        ],
    )

    beam_size = params.positive_integer("beam_size", default=10)

    if language_mode == LanguageMode.CHINESE and learner_type not in [
            "integrated-learner",
            "integrated-learner-recognizer",
    ]:
        raise RuntimeError(
            "Only able to test Chinese with integrated learner.")

    rng = random.Random()
    rng.seed(0)
    perception_generator = GAILA_PHASE_1_PERCEPTION_GENERATOR

    objects = [YOU_HACK, ME_HACK]
    objects.extend(PHASE_1_CURRICULUM_OBJECTS)

    # Eval hack! This is specific to the Phase 1 ontology
    object_recognizer = ObjectRecognizer.for_ontology_types(
        objects,
        determiners=ENGLISH_DETERMINERS,
        ontology=GAILA_PHASE_1_ONTOLOGY,
        language_mode=language_mode,
        perception_generator=perception_generator,
    )

    if learner_type == "pursuit":
        return lambda: ObjectPursuitLearner.from_parameters(
            params.namespace("pursuit"), graph_logger=graph_logger)
    elif learner_type == "pursuit-gaze":
        return lambda: IntegratedTemplateLearner(
            object_learner=PursuitObjectLearnerNew(
                learning_factor=0.05,
                graph_match_confirmation_threshold=0.7,
                lexicon_entry_threshold=0.7,
                rng=rng,
                smoothing_parameter=0.002,
                ontology=GAILA_PHASE_2_ONTOLOGY,
                language_mode=language_mode,
                rank_gaze_higher=True,
            ),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
        )
    elif learner_type == "object-subset":
        return lambda: SubsetObjectLearner(ontology=GAILA_PHASE_1_ONTOLOGY,
                                           language_mode=LanguageMode.ENGLISH)
    elif learner_type == "attribute-subset":
        return lambda: SubsetAttributeLearner(
            ontology=GAILA_PHASE_1_ONTOLOGY,
            object_recognizer=object_recognizer,
            language_mode=LanguageMode.ENGLISH,
        )
    elif learner_type == "preposition-subset":
        return lambda: SubsetPrepositionLearner(
            # graph_logger=graph_logger,
            object_recognizer=object_recognizer,
            ontology=GAILA_PHASE_1_ONTOLOGY,
            language_mode=LanguageMode.ENGLISH,
        )
    elif learner_type == "verb-subset":
        return lambda: SubsetVerbLearner(
            ontology=GAILA_PHASE_1_ONTOLOGY,
            object_recognizer=object_recognizer,
            language_mode=LanguageMode.ENGLISH,
        )
    elif learner_type == "integrated-learner":
        return lambda: IntegratedTemplateLearner(
            object_learner=SubsetObjectLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            functional_learner=FunctionalLearner(language_mode=language_mode),
        )
    elif learner_type == "integrated-learner-recognizer":
        return lambda: IntegratedTemplateLearner(
            object_learner=ObjectRecognizerAsTemplateLearner(
                object_recognizer=object_recognizer,
                language_mode=language_mode),
            attribute_learner=SubsetAttributeLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            relation_learner=SubsetRelationLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            action_learner=SubsetVerbLearnerNew(
                ontology=GAILA_PHASE_2_ONTOLOGY,
                beam_size=beam_size,
                language_mode=language_mode,
            ),
            functional_learner=FunctionalLearner(language_mode=language_mode),
        )
    else:
        # unreachable: params.string has already validated learner_type
        # against the option list above
        raise RuntimeError(f"Unhandled learner type: {learner_type}")