Example #1
0
def make_scae(config):
    """Assemble an ``SCAE`` instance from a configuration object.

    Args:
        config: Either a dict or an object exposing the sections below as
            attributes. A dict is wrapped in a ``Namespace`` first. Each
            section is expanded as keyword arguments into the constructor
            of the matching component: ``pcae_cnn_encoder``,
            ``pcae_encoder``, ``pcae_template_generator``,
            ``pcae_decoder``, ``ocae_encoder_set_transformer``,
            ``obj_age_regressor``, ``ocae_decoder_capsule`` and ``scae``.

    Returns:
        The constructed ``SCAE`` object.
    """
    cfg = Namespace(**config) if isinstance(config, dict) else config

    # The dict literal is evaluated top to bottom, so the sub-modules are
    # instantiated one after another in the order they are listed here.
    components = {
        'part_encoder': CapsuleImageEncoder(
            encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
            **cfg.pcae_encoder,
        ),
        'template_generator': TemplateGenerator(
            **cfg.pcae_template_generator),
        'part_decoder': TemplateBasedImageDecoder(**cfg.pcae_decoder),
        'obj_encoder': SetTransformer(**cfg.ocae_encoder_set_transformer),
        'obj_age_regressor': nn_ext.MLP_regression(**cfg.obj_age_regressor),
        'obj_decoder': CapsuleObjectDecoder(
            CapsuleLayer(**cfg.ocae_decoder_capsule)),
    }

    return SCAE(**components, **cfg.scae)
Example #2
0
def make_scae(model_params: dict):
    """Assemble an ``SCAE`` instance from raw model parameters.

    Args:
        model_params: Keyword arguments forwarded to
            ``prepare_model_params``; its result is wrapped in a
            ``Namespace`` whose sections are expanded as keyword arguments
            into the component constructors.

    Returns:
        The constructed ``SCAE`` object.
    """
    cfg = Namespace(**prepare_model_params(**model_params))

    # The dict literal is evaluated top to bottom, so the sub-modules are
    # instantiated one after another in the order they are listed here.
    components = {
        'part_encoder': CapsuleImageEncoder(
            encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
            **cfg.pcae_encoder,
        ),
        'template_generator': TemplateGenerator(
            **cfg.pcae_template_generator),
        'part_decoder': TemplateBasedImageDecoder(**cfg.pcae_decoder),
        'obj_encoder': SetTransformer(**cfg.ocae_encoder_set_transformer),
        'obj_decoder': CapsuleObjectDecoder(
            CapsuleLayer(**cfg.ocae_decoder_capsule)),
    }

    return SCAE(**components, **cfg.scae)
Example #3
0
    def test_scae(self):
        """Smoke-test an ``SCAE``: forward pass, loss and accuracy.

        The model is assembled from ``model_params`` and run once on random
        images and labels under ``torch.no_grad()``. Nothing is asserted
        about the results; the test passes as long as no call raises.
        """
        cfg = Namespace(**factory.prepare_model_params(**model_params))

        # The dict literal is evaluated top to bottom, so the sub-modules
        # are instantiated in the order they are listed here.
        components = {
            'part_encoder': CapsuleImageEncoder(
                encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
                **cfg.pcae_encoder,
            ),
            'template_generator': TemplateGenerator(
                **cfg.pcae_template_generator),
            'part_decoder': TemplateBasedImageDecoder(**cfg.pcae_decoder),
            'obj_encoder': SetTransformer(
                **cfg.ocae_encoder_set_transformer),
            'obj_decoder': CapsuleObjectDecoder(
                CapsuleLayer(**cfg.ocae_decoder_capsule)),
        }
        model = SCAE(**components, **cfg.scae)

        with torch.no_grad():
            n_samples = 24
            images = torch.rand(n_samples, *cfg.image_shape)
            labels = torch.randint(0, cfg.n_classes, (n_samples, ))

            outputs = model(image=images)
            # The input images double as the reconstruction target.
            loss = model.loss(outputs, images, labels)
            accuracy = model.calculate_accuracy(outputs, labels)
Example #4
0
    def test_capsule_likelihood(self):
        """Check the shape of every field returned by ``CapsuleObjectDecoder``.

        A ``CapsuleLayer`` built from a fixed config is wrapped in a
        ``CapsuleObjectDecoder`` and called once on random inputs under
        ``torch.no_grad()``. Each result field is then compared against
        its expected shape.
        """
        capsule_layer_config = AttrDict(n_caps=32,
                                        dim_feature=256,
                                        n_votes=40,
                                        dim_caps=32,
                                        hidden_sizes=(128, ),
                                        learn_vote_scale=True,
                                        allow_deformations=True,
                                        noise_type='uniform',
                                        noise_scale=4.,
                                        similarity_transform=False,
                                        caps_dropout_rate=0.0)
        capsule_layer = CapsuleLayer(**capsule_layer_config)
        capsule_obj_decoder = CapsuleObjectDecoder(capsule_layer)

        B = 24
        O = capsule_layer_config.n_caps
        D = capsule_layer_config.dim_feature
        V = capsule_layer_config.n_votes
        P = 6

        h = torch.rand(B, O, D)
        x = torch.rand(B, V, P)
        presence = torch.rand(B, V)

        with torch.no_grad():
            result = capsule_obj_decoder(h, x, presence)

        # One entry per result field; ``()`` denotes a scalar tensor.
        expected_shapes = {
            'vote': (B, O, V, P),
            'scale': (B, O, V),
            'vote_presence': (B, O, V),
            'presence_logit_per_caps': (B, O, 1),
            'presence_logit_per_vote': (B, O, V),
            'cpr_dynamic_reg_loss': (),
            'log_prob': (),
            'winner': (B, V, P),
            'winner_presence': (B, V),
            'soft_winner': (B, V, P),
            'soft_winner_presence': (B, V),
            'posterior_mixing_prob': (B, O, V),
            'mixing_logit': (B, O + 1, V),
            'mixing_log_prob': (B, O + 1, V),
            'caps_presence': (B, O),
        }
        for field, shape in expected_shapes.items():
            # assertEqual reports both shapes on failure, unlike
            # assertTrue(a == b), which only says "False is not true".
            self.assertEqual(tuple(getattr(result, field).shape), shape,
                             msg=f'unexpected shape for result.{field}')
Example #5
0
    def test_capsule_layer(self):
        """Check the shape of every field returned by ``CapsuleLayer``.

        The layer is built from a fixed config and called once on a random
        feature tensor under ``torch.no_grad()``, with both
        ``parent_presence`` and ``parent_transform`` set to ``None``. Each
        result field is then compared against its expected shape.
        """
        capsule_layer_config = AttrDict(n_caps=32,
                                        dim_feature=256,
                                        n_votes=40,
                                        dim_caps=32,
                                        hidden_sizes=(128, ),
                                        learn_vote_scale=True,
                                        allow_deformations=True,
                                        noise_type='uniform',
                                        noise_scale=4.,
                                        similarity_transform=False,
                                        caps_dropout_rate=0.0)

        B = 24
        O = capsule_layer_config.n_caps
        F = capsule_layer_config.dim_feature
        V = capsule_layer_config.n_votes

        feature = torch.rand(B, O, F)
        parent_transform = None
        parent_presence = None

        capsule_layer = CapsuleLayer(**capsule_layer_config)

        with torch.no_grad():
            result = capsule_layer(feature,
                                   parent_presence=parent_presence,
                                   parent_transform=parent_transform)

        # One entry per result field; ``()`` denotes a scalar tensor.
        expected_shapes = {
            'vote': (B, O, V, 3, 3),
            'scale': (B, O, V),
            'vote_presence': (B, O, V),
            'presence_logit_per_caps': (B, O, 1),
            'presence_logit_per_vote': (B, O, V),
            'cpr_dynamic_reg_loss': (),
        }
        for field, shape in expected_shapes.items():
            # assertEqual reports both shapes on failure, unlike
            # assertTrue(a == b), which only says "False is not true".
            self.assertEqual(tuple(getattr(result, field).shape), shape,
                             msg=f'unexpected shape for result.{field}')