示例#1
0
def make_scae(config):
    """Assemble a Stacked Capsule Autoencoder (SCAE) from a configuration.

    Args:
        config: either a ``dict`` (wrapped into ``Namespace`` by keyword
            expansion) or an object already exposing the sub-configs as
            attributes. The sub-configs ``pcae_cnn_encoder``, ``pcae_encoder``,
            ``pcae_template_generator``, ``pcae_decoder``,
            ``ocae_encoder_set_transformer``, ``obj_age_regressor``,
            ``ocae_decoder_capsule`` and ``scae`` are each unpacked with ``**``
            as keyword arguments to the matching constructor.

    Returns:
        The ``SCAE`` instance wired with part encoder/decoder, template
        generator, object encoder/decoder and the object age regressor.
    """
    # A plain dict is wrapped so every sub-config can be read as an attribute.
    cfg = Namespace(**config) if isinstance(config, dict) else config

    # NOTE: the construction order below matches the original on purpose;
    # module constructors may draw from the global RNG when initialising
    # weights (assumption), so reordering could change seeded results.
    part_encoder = CapsuleImageEncoder(
        encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
        **cfg.pcae_encoder)

    template_generator = TemplateGenerator(**cfg.pcae_template_generator)
    part_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)

    obj_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    obj_age_regressor = nn_ext.MLP_regression(**cfg.obj_age_regressor)
    obj_decoder = CapsuleObjectDecoder(CapsuleLayer(**cfg.ocae_decoder_capsule))

    return SCAE(
        part_encoder=part_encoder,
        template_generator=template_generator,
        part_decoder=part_decoder,
        obj_encoder=obj_encoder,
        obj_decoder=obj_decoder,
        obj_age_regressor=obj_age_regressor,
        **cfg.scae)
示例#2
0
    def test_scae(self):
        """Smoke-test a hand-assembled SCAE: forward pass, loss and accuracy.

        The test only checks that these calls complete without raising; the
        values returned by ``loss`` and ``calculate_accuracy`` are not
        asserted on.
        """
        cfg = Namespace(**factory.prepare_model_params(**model_params))

        # Built in the same order as before, since constructors may consume
        # the global RNG (assumption).
        part_encoder = CapsuleImageEncoder(
            encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
            **cfg.pcae_encoder)
        template_generator = TemplateGenerator(**cfg.pcae_template_generator)
        part_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)
        obj_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
        obj_decoder = CapsuleObjectDecoder(
            CapsuleLayer(**cfg.ocae_decoder_capsule))

        model = SCAE(part_encoder=part_encoder,
                     template_generator=template_generator,
                     part_decoder=part_decoder,
                     obj_encoder=obj_encoder,
                     obj_decoder=obj_decoder,
                     **cfg.scae)

        with torch.no_grad():
            n_images = 24
            images = torch.rand(n_images, *cfg.image_shape)
            labels = torch.randint(0, cfg.n_classes, (n_images, ))

            output = model(image=images)
            # The input images double as the reconstruction target.
            loss = model.loss(output, images, labels)
            accuracy = model.calculate_accuracy(output, labels)
示例#3
0
def make_scae(model_params: dict):
    """Build an SCAE model from raw model parameters.

    ``model_params`` is expanded into ``prepare_model_params``. The result is
    then wrapped in a ``Namespace`` whose attributes (``pcae_cnn_encoder``,
    ``pcae_encoder``, ``pcae_template_generator``, ``pcae_decoder``,
    ``ocae_encoder_set_transformer``, ``ocae_decoder_capsule``, ``scae``) are
    unpacked as keyword arguments to the corresponding constructors.

    Returns:
        The assembled ``SCAE`` instance (without an age regressor).
    """
    cfg = Namespace(**prepare_model_params(**model_params))

    # Keep this construction order: module constructors may draw from the
    # global RNG when initialising weights (assumption).
    part_encoder = CapsuleImageEncoder(
        encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
        **cfg.pcae_encoder
    )
    template_generator = TemplateGenerator(**cfg.pcae_template_generator)
    part_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)
    obj_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    obj_decoder = CapsuleObjectDecoder(CapsuleLayer(**cfg.ocae_decoder_capsule))

    return SCAE(
        part_encoder=part_encoder,
        template_generator=template_generator,
        part_decoder=part_decoder,
        obj_encoder=obj_encoder,
        obj_decoder=obj_decoder,
        **cfg.scae
    )
    def test_set_transformer_with_isab(self):
        """Check the output shape of a ``SetTransformer`` built with inducing points.

        Passing ``n_inducing_points`` presumably selects induced set attention
        blocks (ISAB); this is inferred from the test name, TODO confirm
        against ``SetTransformer``. The forward pass on a ``(B, N, dim_in)``
        input must yield a tensor of shape ``(B, n_outputs, dim_out)``.
        """
        B = 32
        N = 40

        n_heads = 1
        dim_in = 16
        dim_hidden = 16
        dim_out = 256
        n_outputs = 10
        n_layers = 3

        x = torch.rand(B, N, dim_in)
        # No presence mask is supplied for this input.
        presence = None

        with torch.no_grad():
            st = SetTransformer(dim_in,
                                dim_hidden,
                                dim_out,
                                n_outputs,
                                n_layers,
                                n_heads,
                                layer_norm=True,
                                n_inducing_points=20)
            out = st(x, presence)

        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b), which only reports "False is not true".
        self.assertEqual(tuple(out.shape), (B, n_outputs, dim_out))