コード例 #1
0
    def __init__(
            self,
            entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
            train_entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 14),
            validation_entity_counts=(13, ),
            test_entity_counts=(15, ),
            validation_combinations=(
                ('square', 'red', 'solid'),
                ('triangle', 'green', 'solid'),
                ('circle', 'blue', 'solid')),
            test_combinations=(
                ('rectangle', 'yellow', 'solid'),
                ('cross', 'magenta', 'solid'),
                ('ellipse', 'cyan', 'solid')),
            caption_size=12,
            vocabulary=(
                '.', 'a', 'all', 'an', 'are', 'blue', 'circle', 'circles',
                'cross', 'crosses', 'cyan', 'eight', 'ellipse', 'ellipses',
                'few', 'five', 'four', 'gray', 'green', 'half', 'is',
                'magenta', 'most', 'no', 'none', 'of', 'pentagon',
                'pentagons', 'quarter', 'quarters', 'rectangle',
                'rectangles', 'red', 'semicircle', 'semicircles', 'seven',
                'shape', 'shapes', 'six', 'square', 'squares', 'the',
                'third', 'thirds', 'three', 'triangle', 'triangles', 'two',
                'yellow'),
            language=None):
        """Assemble the generator and captioners for ratio quantification.

        The entity-count and combination arguments are handed unchanged to
        a ``ReinforcedAttributesGenerator``; ``caption_size``,
        ``vocabulary`` and ``language`` go straight to the parent
        constructor together with the generator and captioner built here.
        """
        # NOTE: an earlier variant (kept only as a comment in the original)
        # started a LimitedAttributesGenerator call with
        # shapes_range=(2, 4), colors_range=(2, 4), textures_range=(1, 1).
        # Collision/boundary tolerances and the provoked-collision rate are
        # all fixed at zero for this dataset.
        generator_settings = dict(
            reinforcement_range=(1, 3),
            entity_counts=entity_counts,
            collision_tolerance=0.0,
            boundary_tolerance=0.0,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            test_entity_counts=test_entity_counts,
            validation_combinations=validation_combinations,
            test_combinations=test_combinations,
            max_provoke_collision_rate=0.0,
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # Restrictor part of the quantified statement.
        restrictor = RegularTypeCaptioner(
            hypernym_rate=1.0,
            logical_tautology_rate=1.0,
        )
        # Body part: a mix of attribute and type captioners.
        attribute_or_type = CaptionerMixer(
            captioners=(RegularAttributeCaptioner(), RegularTypeCaptioner()),
        )
        body = AttributeTypeRelationCaptioner(
            attribute_type_captioner=attribute_or_type,
        )

        quantifier_captioner = QuantifierCaptioner(
            restrictor_captioner=restrictor,
            body_captioner=body,
            quantifiers=('ratio', ),
        )
        number_bound_captioner = NumberBoundCaptioner(
            quantifier_captioner=quantifier_captioner,
        )

        # Both captioners get the same weight in the mixer.
        world_captioner = CaptionerMixer(
            captioners=(quantifier_captioner, number_bound_captioner),
            distribution=[1, 1],
        )

        super(QuantificationRatioSimple, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            language=language,
        )
コード例 #2
0
    def __init__(
            self,
            world_size=64,
            entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
            train_entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 14),
            validation_entity_counts=(13, ),
            test_entity_counts=(15, ),
            max_provoke_collision_rate=0.33,
            collision_tolerance=0.2,
            boundary_tolerance=0.2,
            pixel_noise_stddev=0.0):
        """Set up the world generator and derive the number of classes.

        Every argument except ``pixel_noise_stddev`` is forwarded to a
        ``ReinforcedAttributesGenerator``. The class count passed to the
        parent constructor is the product of the generator's shape, color
        and texture counts.
        """
        generator_settings = dict(
            reinforcement_range=(1, 3),
            entity_counts=entity_counts,
            world_size=world_size,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            test_entity_counts=test_entity_counts,
            max_provoke_collision_rate=max_provoke_collision_rate,
            collision_tolerance=collision_tolerance,
            boundary_tolerance=boundary_tolerance,
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # Read the attribute collections back from the generator rather
        # than from the arguments, since none of them are passed in here.
        shape_total = len(world_generator.shapes)
        color_total = len(world_generator.colors)
        texture_total = len(world_generator.textures)
        num_classes = shape_total * color_total * texture_total

        super(Countshape, self).__init__(
            world_generator=world_generator,
            num_classes=num_classes,
            multi_class=True,
            class_count=True,
            pixel_noise_stddev=pixel_noise_stddev,
        )
コード例 #3
0
    def __init__(
        self,
        entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
        train_entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 14),
        validation_entity_counts=(13,),
        test_entity_counts=(15,),
        validation_combinations=(
            ('square', 'red', 'solid'),
            ('triangle', 'green', 'solid'),
            ('circle', 'blue', 'solid')
        ),
        test_combinations=(
            ('rectangle', 'yellow', 'solid'),
            ('cross', 'magenta', 'solid'),
            ('ellipse', 'cyan', 'solid')
        ),
        caption_size=14,
        vocabulary=(
            '.', 'a', 'an', 'biggest', 'blue', 'circle', 'cross', 'cyan',
            'darkest', 'ellipse', 'gray', 'green', 'is', 'leftmost',
            'lightest', 'lowermost', 'magenta', 'most', 'pentagon',
            'rectangle', 'red', 'rightmost', 'semicircle', 'shape',
            'smallest', 'square', 'topmost', 'triangle', 'yellow'
        ),
        language=None
    ):
        """Assemble the generator and captioner for max-attribute captions.

        The entity-count and combination arguments are forwarded to a
        ``ReinforcedAttributesGenerator``; ``caption_size``, ``vocabulary``
        and ``language`` are forwarded to the parent constructor.
        """
        generator_settings = dict(
            reinforcement_range=(1, 1),
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            test_entity_counts=test_entity_counts,
            validation_combinations=validation_combinations,
            test_combinations=test_combinations
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # Restrictor: a max-attribute captioner scoped by a type captioner.
        scope = RegularTypeCaptioner(
            hypernym_rate=1.0,
            logical_tautology_rate=1.0
        )
        restrictor = MaxAttributeCaptioner(scope_captioner=scope)

        # Body: a mix of attribute and type captioners.
        attribute_or_type = CaptionerMixer(
            captioners=(RegularAttributeCaptioner(), RegularTypeCaptioner())
        )
        body = AttributeTypeRelationCaptioner(
            attribute_type_captioner=attribute_or_type
        )

        world_captioner = ExistentialCaptioner(
            restrictor_captioner=restrictor,
            body_captioner=body
        )

        super(Maxattr, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            language=language
        )
コード例 #4
0
    def __init__(
        self,
        entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
        train_entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 14),
        validation_entity_counts=(13,),
        test_entity_counts=(15,),
        validation_combinations=(
            ('square', 'red', 'solid'),
            ('triangle', 'green', 'solid'),
            ('circle', 'blue', 'solid')
        ),
        test_combinations=(
            ('rectangle', 'yellow', 'solid'),
            ('cross', 'magenta', 'solid'),
            ('ellipse', 'cyan', 'solid')
        ),
        caption_size=14,
        vocabulary=(
            '.', 'a', 'above', 'an', 'behind', 'below', 'bigger', 'biggest',
            'blue', 'circle', 'closer', 'closest', 'cross', 'cyan', 'darker',
            'darkest', 'ellipse', 'farther', 'farthest', 'from', 'front',
            'gray', 'green', 'in', 'is', 'left', 'leftmost', 'lighter',
            'lightest', 'lowermost', 'magenta', 'most', 'of', 'pentagon',
            'rectangle', 'red', 'right', 'rightmost', 'semicircle', 'shape',
            'smaller', 'smallest', 'square', 'than', 'the', 'to', 'topmost',
            'triangle', 'yellow'
        ),
        language=None
    ):
        """Assemble the generator and captioner for relational captions.

        The entity-count and combination arguments are forwarded to a
        ``ReinforcedAttributesGenerator``; ``caption_size``, ``vocabulary``
        and ``language`` are forwarded to the parent constructor.
        """
        generator_settings = dict(
            reinforcement_range=(1, 1),
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            test_entity_counts=test_entity_counts,
            validation_combinations=validation_combinations,
            test_combinations=test_combinations
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # Three independent type captioners, created in the same order as
        # the keyword arguments below consume them.
        restrictor = RegularTypeCaptioner()
        reference = RegularTypeCaptioner()
        comparison = RegularTypeCaptioner()

        relation = RelationCaptioner(
            reference_captioner=reference,
            comparison_captioner=comparison
        )
        world_captioner = ExistentialCaptioner(
            restrictor_captioner=restrictor,
            body_captioner=relation
        )

        super(Relational, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            language=language
        )
コード例 #5
0
    def __init__(
            self,
            world_size=64,
            world_color='black',
            shapes=('square', 'rectangle', 'triangle', 'pentagon', 'cross',
                    'circle', 'semicircle', 'ellipse'),
            colors=('red', 'green', 'blue', 'yellow', 'magenta', 'cyan',
                    'gray'),
            textures=('solid', ),
            rotation=True,
            size_range=(0.1, 0.25),
            distortion_range=(2.0, 3.0),
            shade_range=0.4,
            collision_tolerance=0.25,
            collision_shade_difference=0.5,
            boundary_tolerance=None,
            entity_counts=None,
            train_entity_counts=(5, 6, 7, 9, 11, 12, 14),
            validation_entity_counts=(8, 13),
            test_entity_counts=(10, 15),
            validation_count_rate=0.5,
            test_count_rate=0.5,
            validation_combinations=(
                ('square', 'red', 'solid'),
                ('triangle', 'green', 'solid'),
                ('circle', 'blue', 'solid')),
            test_combinations=(
                ('rectangle', 'yellow', 'solid'),
                ('cross', 'magenta', 'solid'),
                ('ellipse', 'cyan', 'solid')),
            validation_space_rate_range=(0.0, 1.0),
            test_space_rate_range=(0.0, 1.0),
            validation_combination_rate=0.5,
            test_combination_rate=0.5,
            max_provoke_collision_rate=0.33,
            reinforcement_range=(1, 3),
            quantifiers=None,
            number_bounds=None,
            comparative_quantifiers=None,
            caption_size=15,
            vocabulary=(
                '.', 'a', 'all', 'an', 'are', 'as', 'at', 'blue', 'but',
                'circle', 'circles', 'cross', 'crosses', 'cyan', 'eight',
                'ellipse', 'ellipses', 'exactly', 'five', 'four', 'gray',
                'green', 'half', 'is', 'least', 'less', 'magenta', 'many',
                'more', 'most', 'no', 'none', 'not', 'of', 'one', 'pentagon',
                'pentagons', 'quarter', 'quarters', 'rectangle',
                'rectangles', 'red', 'semicircle', 'semicircles', 'seven',
                'shape', 'shapes', 'six', 'square', 'squares', 'than', 'the',
                'third', 'thirds', 'three', 'triangle', 'triangles', 'twice',
                'two', 'yellow', 'zero'),
            correct_ratio=0.5,
            train_correct_ratio=None,
            validation_correct_ratio=None,
            test_correct_ratio=None,
            worlds_per_instance=1,
            captions_per_instance=1,
            pixel_noise_stddev=0.0,
            caption_realizer='dmrs',
            language=None):
        """Assemble the generator and captioners for complex quantification.

        World-related arguments are forwarded unchanged to a
        ``ReinforcedAttributesGenerator``. ``quantifiers``,
        ``number_bounds`` and ``comparative_quantifiers`` are forwarded to
        the quantifier, number-bound and comparative-quantifier captioners
        respectively. The remaining arguments are passed to the parent
        constructor.
        """
        generator_settings = dict(
            world_size=world_size,
            world_color=world_color,
            shapes=shapes,
            colors=colors,
            textures=textures,
            rotation=rotation,
            size_range=size_range,
            distortion_range=distortion_range,
            shade_range=shade_range,
            collision_tolerance=collision_tolerance,
            collision_shade_difference=collision_shade_difference,
            boundary_tolerance=boundary_tolerance,
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            validation_count_rate=validation_count_rate,
            test_entity_counts=test_entity_counts,
            test_count_rate=test_count_rate,
            validation_combinations=validation_combinations,
            validation_space_rate_range=validation_space_rate_range,
            validation_combination_rate=validation_combination_rate,
            test_combinations=test_combinations,
            test_space_rate_range=test_space_rate_range,
            test_combination_rate=test_combination_rate,
            max_provoke_collision_rate=max_provoke_collision_rate,
            reinforcement_range=reinforcement_range,
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # The same body captioner instance is shared by the plain and the
        # comparative quantifier captioners below.
        attribute_or_type = CaptionerMixer(
            captioners=(RegularAttributeCaptioner(),
                        RegularTypeCaptioner(hypernym_rate=0.0)),
        )
        body_captioner = AttributeTypeRelationCaptioner(
            attribute_type_captioner=attribute_or_type,
        )

        quantifier_restrictor = RegularTypeCaptioner(hypernym_rate=1.0)
        quantifier_captioner = QuantifierCaptioner(
            restrictor_captioner=quantifier_restrictor,
            body_captioner=body_captioner,
            quantifiers=quantifiers,
        )

        number_bound_captioner = NumberBoundCaptioner(
            quantifier_captioner=quantifier_captioner,
            number_bounds=number_bounds,
        )

        comparative_restrictor = RegularTypeCaptioner(hypernym_rate=1.0)
        comparative_comparison = RegularTypeCaptioner(hypernym_rate=1.0)
        comparative_quantifier_captioner = ComparativeQuantifierCaptioner(
            restrictor_captioner=comparative_restrictor,
            comparison_captioner=comparative_comparison,
            body_captioner=body_captioner,
            comparative_quantifiers=comparative_quantifiers,
        )

        # Equal weights for the three captioner kinds.
        world_captioner = CaptionerMixer(
            captioners=(
                quantifier_captioner,
                number_bound_captioner,
                comparative_quantifier_captioner,
            ),
            distribution=(1, 1, 1),
        )

        super(QuantificationComplexDataset, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            correct_ratio=correct_ratio,
            train_correct_ratio=train_correct_ratio,
            validation_correct_ratio=validation_correct_ratio,
            test_correct_ratio=test_correct_ratio,
            worlds_per_instance=worlds_per_instance,
            captions_per_instance=captions_per_instance,
            pixel_noise_stddev=pixel_noise_stddev,
            caption_realizer=caption_realizer,
            language=language,
        )
コード例 #6
0
ファイル: logical.py プロジェクト: MannyKayy/ShapeWorld
    def __init__(
            self,
            world_size=64,
            world_color='black',
            shapes=('square', 'rectangle', 'triangle', 'pentagon', 'cross',
                    'circle', 'semicircle', 'ellipse'),
            colors=('red', 'green', 'blue', 'yellow', 'magenta', 'cyan',
                    'gray'),
            textures=('solid', ),
            rotation=True,
            size_range=(0.1, 0.25),
            distortion_range=(2.0, 3.0),
            shade_range=0.4,
            collision_tolerance=0.25,
            collision_shade_difference=0.5,
            boundary_tolerance=None,
            entity_counts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
            train_entity_counts=None,
            validation_entity_counts=None,
            test_entity_counts=None,
            validation_count_rate=0.5,
            test_count_rate=0.5,
            validation_combinations=None,
            test_combinations=None,
            validation_space_rate_range=(0.0, 1.0),
            test_space_rate_range=(0.0, 1.0),
            validation_combination_rate=0.5,
            test_combination_rate=0.5,
            max_provoke_collision_rate=0.33,
            reinforcement_range=(1, 3),
            generators=None,
            captioners=None,
            connectives=None,
            caption_size=28,
            vocabulary=(
                '.', 'a', 'above', 'all', 'an', 'and', 'are', 'as', 'at',
                'behind', 'below', 'bigger', 'blue', 'but', 'circle',
                'circles', 'closer', 'color', 'cross', 'crosses', 'cyan',
                'darker', 'different', 'eight', 'either', 'ellipse',
                'ellipses', 'exactly', 'farther', 'few', 'five', 'four',
                'from', 'front', 'gray', 'green', 'half', 'if', 'in', 'is',
                'least', 'left', 'less', 'lighter', 'magenta', 'many',
                'more', 'most', 'no', 'none', 'not', 'of', 'one', 'only',
                'or', 'pentagon', 'pentagons', 'quarter', 'quarters',
                'rectangle', 'rectangles', 'red', 'right', 'same',
                'semicircle', 'semicircles', 'seven', 'shape', 'shapes',
                'six', 'smaller', 'square', 'squares', 'than', 'the',
                'there', 'third', 'thirds', 'three', 'to', 'triangle',
                'triangles', 'twice', 'two', 'yellow', 'zero'),
            correct_ratio=0.5,
            train_correct_ratio=None,
            validation_correct_ratio=None,
            test_correct_ratio=None,
            worlds_per_instance=1,
            captions_per_instance=1,
            pixel_noise_stddev=0.0,
            caption_realizer='dmrs',
            language=None):
        """Assemble the mixed generator and logical-connective captioners.

        ``generators``, ``captioners`` and ``connectives`` each act as an
        optional filter: ``None`` enables every option, otherwise only the
        names contained in the argument are enabled. Recognised names are
        ``'random'``/``'reinforced'`` for generators,
        ``'existential'``/``'relational'``/``'quantification'`` for
        captioners, and ``'conjunction'``/``'disjunction'``/
        ``'implication'``/``'equivalence'`` for connectives.

        World-related arguments are forwarded to the enabled generators
        (``reinforcement_range`` only to the reinforced one); the remaining
        arguments are passed to the parent constructor.
        """

        def _is_enabled(selection, option):
            # ``None`` means "no filter": every option is switched on.
            return selection is None or option in selection

        # Keyword arguments common to both generator types.
        shared_generator_settings = dict(
            world_size=world_size,
            world_color=world_color,
            shapes=shapes,
            colors=colors,
            textures=textures,
            rotation=rotation,
            size_range=size_range,
            distortion_range=distortion_range,
            shade_range=shade_range,
            collision_tolerance=collision_tolerance,
            collision_shade_difference=collision_shade_difference,
            boundary_tolerance=boundary_tolerance,
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            validation_count_rate=validation_count_rate,
            test_entity_counts=test_entity_counts,
            test_count_rate=test_count_rate,
            validation_combinations=validation_combinations,
            validation_space_rate_range=validation_space_rate_range,
            validation_combination_rate=validation_combination_rate,
            test_combinations=test_combinations,
            test_space_rate_range=test_space_rate_range,
            test_combination_rate=test_combination_rate,
            max_provoke_collision_rate=max_provoke_collision_rate,
        )

        enabled_generators = []
        if _is_enabled(generators, 'random'):
            enabled_generators.append(
                RandomAttributesGenerator(**shared_generator_settings))
        if _is_enabled(generators, 'reinforced'):
            enabled_generators.append(
                ReinforcedAttributesGenerator(
                    reinforcement_range=reinforcement_range,
                    **shared_generator_settings))
        world_generator = GeneratorMixer(generators=enabled_generators)

        # Restrictor and body captioners are shared by the existential and
        # quantification captioners below.
        restrictor_captioner = CaptionerMixer(
            captioners=(EmptyTypeCaptioner(),
                        RegularTypeCaptioner(hypernym_rate=1.0)),
        )
        attribute_or_type = CaptionerMixer(
            captioners=(RegularAttributeCaptioner(),
                        RegularTypeCaptioner(hypernym_rate=0.0)),
        )
        body_captioner = AttributeTypeRelationCaptioner(
            attribute_type_captioner=attribute_or_type,
        )

        # Stage 1: the base captioners that the connectives will combine.
        base_captioners = []

        if _is_enabled(captioners, 'existential'):
            plain_type = RegularTypeCaptioner()
            existential = ExistentialCaptioner(
                restrictor_captioner=restrictor_captioner,
                body_captioner=body_captioner,
            )
            base_captioners.append(
                CaptionerMixer(captioners=(plain_type, existential),
                               distribution=(1, 2)))

        if _is_enabled(captioners, 'relational'):
            relational_restrictor = RegularTypeCaptioner()
            relation = RelationCaptioner(
                reference_captioner=RegularTypeCaptioner(),
                comparison_captioner=RegularTypeCaptioner(),
            )
            base_captioners.append(
                ExistentialCaptioner(
                    restrictor_captioner=relational_restrictor,
                    body_captioner=relation))

        if _is_enabled(captioners, 'quantification'):
            base_captioners.append(
                QuantifierCaptioner(
                    restrictor_captioner=restrictor_captioner,
                    body_captioner=body_captioner))

        captioner = CaptionerMixer(captioners=base_captioners)

        # Stage 2: wrap the mixed base captioner in each enabled connective.
        connective_captioners = []

        if _is_enabled(connectives, 'conjunction'):
            connective_captioners.append(
                ConjunctionCaptioner(captioner=captioner))

        if _is_enabled(connectives, 'disjunction'):
            connective_captioners.append(
                DisjunctionCaptioner(captioner=captioner))

        if _is_enabled(connectives, 'implication'):
            connective_captioners.append(
                ImplicationCaptioner(captioner=captioner))

        if _is_enabled(connectives, 'equivalence'):
            connective_captioners.append(
                EquivalenceCaptioner(captioner=captioner))

        world_captioner = CaptionerMixer(captioners=connective_captioners)

        super(LogicalDataset, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            correct_ratio=correct_ratio,
            train_correct_ratio=train_correct_ratio,
            validation_correct_ratio=validation_correct_ratio,
            test_correct_ratio=test_correct_ratio,
            worlds_per_instance=worlds_per_instance,
            captions_per_instance=captions_per_instance,
            pixel_noise_stddev=pixel_noise_stddev,
            caption_realizer=caption_realizer,
            language=language,
        )
コード例 #7
0
    def __init__(
            self,
            world_size=64,
            world_colors=('black', ),
            shapes=('square', 'rectangle', 'triangle', 'pentagon', 'cross',
                    'circle', 'semicircle', 'ellipse'),
            colors=('red', 'green', 'blue', 'yellow', 'magenta', 'cyan',
                    'gray'),
            textures=('solid', ),
            rotation=True,
            size_range=(0.1, 0.25),
            distortion_range=(2.0, 3.0),
            shade_range=0.4,
            collision_tolerance=0.25,
            collision_shade_difference=0.5,
            boundary_tolerance=None,
            entity_counts=(3, 4, 5, 6, 7, 8, 9, 10),
            train_entity_counts=None,
            validation_entity_counts=None,
            test_entity_counts=None,
            validation_count_rate=0.5,
            test_count_rate=0.5,
            validation_combinations=None,
            test_combinations=None,
            validation_space_rate_range=(0.0, 1.0),
            test_space_rate_range=(0.0, 1.0),
            validation_combination_rate=0.5,
            test_combination_rate=0.5,
            max_provoke_collision_rate=0.33,
            relations=None,
            negation=True,
            existential_incorrect_distribution=(1, 1),
            relation_incorrect_distribution=(2, 1, 1),
            type_existing_attribute_rate=1.0,
            type_incorrect_distribution=(1, 1, 1, 1),
            caption_size=15,
            vocabulary=(
                '.', 'a', 'above', 'an', 'as', 'behind', 'below', 'besides',
                'bigger', 'blue', 'circle', 'closer', 'color', 'cross',
                'cyan', 'darker', 'different', 'does', 'ellipse', 'exist',
                'exists', 'farther', 'from', 'front', 'gray', 'green', 'in',
                'is', 'left', 'lighter', 'magenta', 'not', 'of', 'pentagon',
                'rectangle', 'red', 'right', 'same', 'semicircle', 'shape',
                'smaller', 'square', 'than', 'the', 'to', 'triangle',
                'yellow'),
            correct_ratio=0.5,
            train_correct_ratio=None,
            validation_correct_ratio=None,
            test_correct_ratio=None,
            worlds_per_instance=1,
            captions_per_instance=1,
            pixel_noise_stddev=None,
            caption_realizer='dmrs',
            language=None):
        """Assemble the generator and captioner for relational captions.

        World-related arguments are forwarded to a
        ``ReinforcedAttributesGenerator`` whose ``reinforcement_range`` is
        fixed at ``(1, 1)``. The ``type_*`` arguments configure both
        ``RegularTypeCaptioner`` instances, ``relations`` and
        ``relation_incorrect_distribution`` configure the
        ``RelationCaptioner``, and ``existential_incorrect_distribution``
        configures the outer ``ExistentialCaptioner``. When ``negation`` is
        truthy the relation captioner is wrapped in a
        ``NegationRelationCaptioner``. The remaining arguments are passed
        to the parent constructor.
        """
        generator_settings = dict(
            world_size=world_size,
            world_colors=world_colors,
            shapes=shapes,
            colors=colors,
            textures=textures,
            rotation=rotation,
            size_range=size_range,
            distortion_range=distortion_range,
            shade_range=shade_range,
            collision_tolerance=collision_tolerance,
            collision_shade_difference=collision_shade_difference,
            boundary_tolerance=boundary_tolerance,
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            validation_count_rate=validation_count_rate,
            test_entity_counts=test_entity_counts,
            test_count_rate=test_count_rate,
            validation_combinations=validation_combinations,
            validation_space_rate_range=validation_space_rate_range,
            validation_combination_rate=validation_combination_rate,
            test_combinations=test_combinations,
            test_space_rate_range=test_space_rate_range,
            test_combination_rate=test_combination_rate,
            max_provoke_collision_rate=max_provoke_collision_rate,
            reinforcement_range=(1, 1),
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # Both type captioners (reference and restrictor) use the same
        # configuration.
        type_settings = dict(
            existing_attribute_rate=type_existing_attribute_rate,
            incorrect_distribution=type_incorrect_distribution,
        )

        reference = RegularTypeCaptioner(**type_settings)
        comparison = UniqueTypeCaptioner()
        body = RelationCaptioner(
            reference_captioner=reference,
            comparison_captioner=comparison,
            relations=relations,
            incorrect_distribution=relation_incorrect_distribution,
        )
        if negation:
            body = NegationRelationCaptioner(relation_captioner=body)

        restrictor = RegularTypeCaptioner(**type_settings)
        world_captioner = ExistentialCaptioner(
            restrictor_captioner=restrictor,
            body_captioner=body,
            incorrect_distribution=existential_incorrect_distribution,
        )

        super(RelationalDataset, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            correct_ratio=correct_ratio,
            train_correct_ratio=train_correct_ratio,
            validation_correct_ratio=validation_correct_ratio,
            test_correct_ratio=test_correct_ratio,
            worlds_per_instance=worlds_per_instance,
            captions_per_instance=captions_per_instance,
            pixel_noise_stddev=pixel_noise_stddev,
            caption_realizer=caption_realizer,
            language=language,
        )
コード例 #8
0
    def __init__(
        self,
        entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
        train_entity_counts=(5, 6, 7, 8, 9, 10, 11, 12, 14),
        validation_entity_counts=(13,),
        test_entity_counts=(15,),
        validation_combinations=(
            ('square', 'red', 'solid'),
            ('triangle', 'green', 'solid'),
            ('circle', 'blue', 'solid')
        ),
        test_combinations=(
            ('rectangle', 'yellow', 'solid'),
            ('cross', 'magenta', 'solid'),
            ('ellipse', 'cyan', 'solid')
        ),
        caption_size=19,
        vocabulary=(
            '.', 'a', 'above', 'all', 'an', 'are', 'as', 'at', 'behind',
            'below', 'bigger', 'biggest', 'blue', 'both', 'but', 'circle',
            'circles', 'closer', 'closest', 'cross', 'crosses', 'cyan',
            'darker', 'darkest', 'eight', 'ellipse', 'ellipses', 'exactly',
            'farther', 'farthest', 'five', 'four', 'from', 'front', 'gray',
            'green', 'half', 'in', 'is', 'least', 'left', 'leftmost', 'less',
            'lighter', 'lightest', 'lowermost', 'magenta', 'many', 'more',
            'most', 'not', 'of', 'one', 'pentagon', 'pentagons', 'rectangle',
            'rectangles', 'red', 'right', 'rightmost', 'semicircle',
            'semicircles', 'seven', 'shape', 'shapes', 'six', 'smaller',
            'smallest', 'square', 'squares', 'than', 'the', 'three', 'to',
            'topmost', 'triangle', 'triangles', 'twice', 'two', 'yellow',
            'zero'
        ),
        language=None
    ):
        """Assemble the generator and captioners for count quantification.

        The entity-count and combination arguments are forwarded to a
        ``ReinforcedAttributesGenerator``; ``caption_size``, ``vocabulary``
        and ``language`` are forwarded to the parent constructor.
        """
        generator_settings = dict(
            reinforcement_range=(1, 3),
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            test_entity_counts=test_entity_counts,
            validation_combinations=validation_combinations,
            test_combinations=test_combinations
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # Body captioner: attribute/type statements mixed with relation
        # statements at weights 1 and 2. The same instance is reused by the
        # quantifier and comparative-quantifier captioners.
        attribute_or_type = CaptionerMixer(
            captioners=(RegularAttributeCaptioner(), RegularTypeCaptioner())
        )
        attribute_branch = AttributeTypeRelationCaptioner(
            attribute_type_captioner=attribute_or_type
        )
        relation_branch = RelationCaptioner(
            reference_captioner=RegularTypeCaptioner(),
            comparison_captioner=RegularTypeCaptioner()
        )
        body_captioner = CaptionerMixer(
            captioners=(attribute_branch, relation_branch),
            distribution=[1, 2]
        )

        quantifier_restrictor = RegularTypeCaptioner(
            hypernym_rate=1.0,
            logical_tautology_rate=1.0
        )
        quantifier_captioner = QuantifierCaptioner(
            restrictor_captioner=quantifier_restrictor,
            body_captioner=body_captioner,
            quantifiers=('count',)
        )

        number_bound_captioner = NumberBoundCaptioner(
            quantifier_captioner=quantifier_captioner
        )

        comparative_restrictor = RegularTypeCaptioner(hypernym_rate=1.0)
        comparative_comparison = RegularTypeCaptioner(hypernym_rate=1.0)
        comparative_quantifier_captioner = ComparativeQuantifierCaptioner(
            restrictor_captioner=comparative_restrictor,
            comparison_captioner=comparative_comparison,
            body_captioner=body_captioner
        )

        # Equal weights for the three captioner kinds.
        world_captioner = CaptionerMixer(
            captioners=(
                quantifier_captioner,
                number_bound_captioner,
                comparative_quantifier_captioner
            ),
            distribution=[1, 1, 1]
        )

        super(QuantificationCount, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            language=language
        )
コード例 #9
0
    def __init__(
            self,
            world_size=64,
            world_color='black',
            shapes=('square', 'rectangle', 'triangle', 'pentagon', 'cross',
                    'circle', 'semicircle', 'ellipse'),
            colors=('red', 'green', 'blue', 'yellow', 'magenta', 'cyan',
                    'gray'),
            textures=('solid', ),
            rotation=True,
            size_range=(0.1, 0.25),
            distortion_range=(2.0, 3.0),
            shade_range=0.4,
            collision_tolerance=0.25,
            collision_shade_difference=0.5,
            boundary_tolerance=0.25,
            entity_counts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
            train_entity_counts=(1, 2, 4, 6, 7, 9, 11, 12, 14),
            validation_entity_counts=(3, 8, 13),
            validation_count_rate=0.5,
            test_entity_counts=(5, 10, 15),
            test_count_rate=0.5,
            validation_combinations=(('square', 'red',
                                      'solid'), ('triangle', 'green', 'solid'),
                                     ('circle', 'blue', 'solid')),
            validation_space_rate_range=(0.0, 1.0),
            validation_combination_rate=0.5,
            test_combinations=(('rectangle', 'yellow',
                                'solid'), ('cross', 'magenta', 'solid'),
                               ('ellipse', 'cyan', 'solid')),
            test_space_rate_range=(0.0, 1.0),
            test_combination_rate=0.5,
            max_provoke_collision_rate=0.33,
            reinforcement_range=(1, 3),
            caption_size=28,
            vocabulary=('.', 'a', 'above', 'all', 'an', 'and', 'are', 'as',
                        'at', 'behind', 'below', 'bigger', 'blue', 'but',
                        'circle', 'circles', 'closer', 'cross', 'crosses',
                        'cyan', 'darker', 'eight', 'either', 'ellipse',
                        'ellipses', 'exactly', 'farther', 'few', 'five',
                        'four', 'from', 'front', 'gray', 'green', 'half', 'in',
                        'is', 'least', 'left', 'less', 'lighter', 'magenta',
                        'many', 'more', 'most', 'no', 'none', 'not', 'of',
                        'one', 'or', 'pentagon', 'pentagons', 'quarter',
                        'quarters', 'rectangle', 'rectangles', 'red', 'right',
                        'semicircle', 'semicircles', 'seven', 'shape',
                        'shapes', 'six', 'smaller', 'square', 'squares',
                        'than', 'the', 'there', 'third', 'thirds', 'three',
                        'to', 'triangle', 'triangles', 'twice', 'two',
                        'yellow', 'zero'),
            correct_ratio=0.5,
            train_correct_ratio=None,
            validation_correct_ratio=None,
            test_correct_ratio=None,
            worlds_per_instance=1,
            captions_per_instance=1,
            pixel_noise_stddev=0.0,
            caption_realizer='dmrs',
            language=None):
        """Assemble the generators and captioners of the combination dataset.

        Worlds come from a ``GeneratorMixer`` over a
        ``RandomAttributesGenerator`` and a ``ReinforcedAttributesGenerator``
        that receive identical world settings. Captions come from a
        ``CaptionerMixer`` over a ``ConjunctionCaptioner`` and a
        ``DisjunctionCaptioner``, each combining the same existential,
        relation and quantification captioner instances.

        Args:
            world_size, world_color, shapes, colors, textures, rotation,
            size_range, distortion_range, shade_range, collision_tolerance,
            collision_shade_difference, boundary_tolerance, entity_counts,
            train_entity_counts, validation_entity_counts,
            validation_count_rate, test_entity_counts, test_count_rate,
            validation_combinations, validation_space_rate_range,
            validation_combination_rate, test_combinations,
            test_space_rate_range, test_combination_rate,
            max_provoke_collision_rate: Passed unchanged to both world
                generators.
            reinforcement_range: Passed only to the
                ``ReinforcedAttributesGenerator``.
            caption_size, vocabulary, correct_ratio, train_correct_ratio,
            validation_correct_ratio, test_correct_ratio,
            worlds_per_instance, captions_per_instance, pixel_noise_stddev,
            caption_realizer, language: Passed unchanged to the parent
                class initializer.
        """
        # Both generators share one set of world settings, so it is built
        # once.  The ``*_count_rate``, ``*_space_rate_range`` and
        # ``*_combination_rate`` arguments are included here so that values
        # supplied by the caller actually reach the generators instead of
        # being accepted and then dropped.
        # NOTE(review): assumes both generator classes accept these keyword
        # arguments - confirm against the generators module.
        generator_kwargs = dict(
            world_size=world_size,
            world_color=world_color,
            shapes=shapes,
            colors=colors,
            textures=textures,
            rotation=rotation,
            size_range=size_range,
            distortion_range=distortion_range,
            shade_range=shade_range,
            collision_tolerance=collision_tolerance,
            collision_shade_difference=collision_shade_difference,
            boundary_tolerance=boundary_tolerance,
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            validation_count_rate=validation_count_rate,
            test_entity_counts=test_entity_counts,
            test_count_rate=test_count_rate,
            validation_combinations=validation_combinations,
            validation_space_rate_range=validation_space_rate_range,
            validation_combination_rate=validation_combination_rate,
            test_combinations=test_combinations,
            test_space_rate_range=test_space_rate_range,
            test_combination_rate=test_combination_rate,
            max_provoke_collision_rate=max_provoke_collision_rate)
        random_generator = RandomAttributesGenerator(**generator_kwargs)
        reinforced_attributes_generator = ReinforcedAttributesGenerator(
            reinforcement_range=reinforcement_range, **generator_kwargs)
        world_generator = GeneratorMixer(
            generators=(random_generator, reinforced_attributes_generator))

        # A single body captioner instance is reused by the existential,
        # quantifier and comparative-quantifier captioners below.
        body_captioner = AttributeTypeRelationCaptioner(
            attribute_type_captioner=CaptionerMixer(
                captioners=(RegularAttributeCaptioner(),
                            RegularTypeCaptioner())))
        existential_captioner = CaptionerMixer(captioners=(
            RegularTypeCaptioner(),
            ExistentialCaptioner(
                restrictor_captioner=RegularTypeCaptioner(
                    hypernym_rate=1.0, logical_tautology_rate=1.0),
                body_captioner=body_captioner)))
        relation_captioner = ExistentialCaptioner(
            restrictor_captioner=RegularTypeCaptioner(),
            body_captioner=RelationCaptioner(
                reference_captioner=RegularTypeCaptioner(),
                comparison_captioner=RegularTypeCaptioner()))
        quantifier_captioner = QuantifierCaptioner(
            restrictor_captioner=RegularTypeCaptioner(
                hypernym_rate=1.0, logical_tautology_rate=1.0),
            body_captioner=body_captioner)
        number_bound_captioner = NumberBoundCaptioner(
            quantifier_captioner=quantifier_captioner)
        comparative_quantifier_captioner = ComparativeQuantifierCaptioner(
            restrictor_captioner=RegularTypeCaptioner(hypernym_rate=1.0),
            comparison_captioner=RegularTypeCaptioner(hypernym_rate=1.0),
            body_captioner=body_captioner)
        quantification_captioner = CaptionerMixer(
            captioners=(quantifier_captioner, number_bound_captioner,
                        comparative_quantifier_captioner),
            distribution=[2, 2, 1])

        # The conjunction and the disjunction captioner are given the very
        # same three captioner objects, not copies.
        combined_captioners = (existential_captioner, relation_captioner,
                               quantification_captioner)
        world_captioner = CaptionerMixer(captioners=(
            ConjunctionCaptioner(captioners=combined_captioners),
            DisjunctionCaptioner(captioners=combined_captioners)))

        super(Combination, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            correct_ratio=correct_ratio,
            train_correct_ratio=train_correct_ratio,
            validation_correct_ratio=validation_correct_ratio,
            test_correct_ratio=test_correct_ratio,
            worlds_per_instance=worlds_per_instance,
            captions_per_instance=captions_per_instance,
            pixel_noise_stddev=pixel_noise_stddev,
            caption_realizer=caption_realizer,
            language=language)
コード例 #10
0
ファイル: selection.py プロジェクト: HuiyuanXie/ShapeWorld
    def __init__(
            self,
            world_size=64,
            world_colors=('black', ),
            shapes=('square', 'rectangle', 'triangle', 'pentagon', 'cross',
                    'circle', 'semicircle', 'ellipse'),
            colors=('red', 'green', 'blue', 'yellow', 'magenta', 'cyan',
                    'gray'),
            textures=('solid', ),
            rotation=True,
            size_range=(0.1, 0.25),
            distortion_range=(2.0, 3.0),
            shade_range=0.4,
            collision_tolerance=0.25,
            collision_shade_difference=0.5,
            boundary_tolerance=None,
            entity_counts=(4, 5, 6, 7, 8, 9, 10),
            train_entity_counts=None,
            validation_entity_counts=None,
            test_entity_counts=None,
            validation_count_rate=0.5,
            test_count_rate=0.5,
            validation_combinations=None,
            test_combinations=None,
            validation_space_rate_range=(0.0, 1.0),
            test_space_rate_range=(0.0, 1.0),
            validation_combination_rate=0.5,
            test_combination_rate=0.5,
            max_provoke_collision_rate=0.33,
            allow_empty_scope=True,
            selectors=None,
            caption_size=14,
            vocabulary=('.', 'a', 'an', 'are', 'bigger', 'biggest', 'blue',
                        'circle', 'circles', 'closer', 'closest', 'cross',
                        'crosses', 'cyan', 'darker', 'darkest', 'ellipse',
                        'ellipses', 'farther', 'farthest', 'five', 'four',
                        'from', 'gray', 'green', 'is', 'left', 'leftmost',
                        'lighter', 'lightest', 'lower', 'lowermost', 'magenta',
                        'one', 'pentagon', 'pentagons', 'rectangle',
                        'rectangles', 'red', 'right', 'rightmost',
                        'semicircle', 'semicircles', 'shape', 'shapes',
                        'smaller', 'smallest', 'square', 'squares', 'the',
                        'three', 'to', 'triangle', 'triangles', 'two', 'upper',
                        'uppermost', 'yellow'),
            correct_ratio=0.5,
            train_correct_ratio=None,
            validation_correct_ratio=None,
            test_correct_ratio=None,
            worlds_per_instance=1,
            captions_per_instance=1,
            pixel_noise_stddev=None,
            caption_realizer='dmrs',
            language=None):
        """Wire up the world generator and captioner of the selection dataset.

        Worlds are produced by a ``ReinforcedAttributesGenerator`` whose
        ``reinforcement_range`` is fixed to ``(1, 1)``. Captions are produced
        by an ``ExistentialCaptioner`` whose restrictor is a
        ``SelectorCaptioner`` and whose body is an
        ``AttributeTypeRelationCaptioner``.

        Args:
            world_size ... max_provoke_collision_rate: World settings handed
                to the generator unchanged.
            allow_empty_scope: When true, an ``EmptyTypeCaptioner`` is
                offered as a scope captioner next to the
                ``RegularTypeCaptioner(hypernym_rate=1.0)``.
            selectors: Handed to the ``SelectorCaptioner`` unchanged.
            caption_size ... language: Handed to the parent class
                initializer unchanged.
        """
        generator_settings = dict(
            # Appearance of the world and its entities.
            world_size=world_size,
            world_colors=world_colors,
            shapes=shapes,
            colors=colors,
            textures=textures,
            rotation=rotation,
            size_range=size_range,
            distortion_range=distortion_range,
            shade_range=shade_range,
            # Placement tolerances.
            collision_tolerance=collision_tolerance,
            collision_shade_difference=collision_shade_difference,
            boundary_tolerance=boundary_tolerance,
            max_provoke_collision_rate=max_provoke_collision_rate,
            # Entity counts per split.
            entity_counts=entity_counts,
            train_entity_counts=train_entity_counts,
            validation_entity_counts=validation_entity_counts,
            validation_count_rate=validation_count_rate,
            test_entity_counts=test_entity_counts,
            test_count_rate=test_count_rate,
            # Attribute combinations per split.
            validation_combinations=validation_combinations,
            validation_space_rate_range=validation_space_rate_range,
            validation_combination_rate=validation_combination_rate,
            test_combinations=test_combinations,
            test_space_rate_range=test_space_rate_range,
            test_combination_rate=test_combination_rate,
            # Not configurable through this constructor.
            reinforcement_range=(1, 1),
        )
        world_generator = ReinforcedAttributesGenerator(**generator_settings)

        # The scope is always describable by a hypernym type; the empty type
        # is an additional option only when the caller allows it.
        scope_options = [RegularTypeCaptioner(hypernym_rate=1.0)] + (
            [EmptyTypeCaptioner()] if allow_empty_scope else [])
        selector_restrictor = SelectorCaptioner(
            scope_captioner=CaptionerMixer(captioners=scope_options),
            comparison_captioner=UniqueTypeCaptioner(),
            selectors=selectors)

        attribute_or_type = CaptionerMixer(captioners=(
            RegularAttributeCaptioner(),
            RegularTypeCaptioner(hypernym_rate=0.0)))
        selection_body = AttributeTypeRelationCaptioner(
            attribute_type_captioner=attribute_or_type)

        world_captioner = ExistentialCaptioner(
            restrictor_captioner=selector_restrictor,
            body_captioner=selection_body)

        super(SelectionDataset, self).__init__(
            world_generator=world_generator,
            world_captioner=world_captioner,
            caption_size=caption_size,
            vocabulary=vocabulary,
            correct_ratio=correct_ratio,
            train_correct_ratio=train_correct_ratio,
            validation_correct_ratio=validation_correct_ratio,
            test_correct_ratio=test_correct_ratio,
            worlds_per_instance=worlds_per_instance,
            captions_per_instance=captions_per_instance,
            pixel_noise_stddev=pixel_noise_stddev,
            caption_realizer=caption_realizer,
            language=language)