def make_scae(config):
    """Assemble an ``SCAE`` model from a configuration object.

    Args:
        config: Either a ``dict`` (wrapped in a ``Namespace`` so that its
            keys become attributes) or any object exposing the attributes
            ``pcae_cnn_encoder``, ``pcae_encoder``,
            ``pcae_template_generator``, ``pcae_decoder``,
            ``ocae_encoder_set_transformer``, ``obj_age_regressor``,
            ``ocae_decoder_capsule`` and ``scae``. Every attribute is
            unpacked with ``**`` as keyword arguments of the matching
            constructor, so each must be a mapping.

    Returns:
        The constructed ``SCAE`` instance.
    """
    cfg = Namespace(**config) if isinstance(config, dict) else config

    # The CNN is wrapped by the capsule image encoder handed to SCAE.
    image_cnn = CNNEncoder(**cfg.pcae_cnn_encoder)
    capsule_encoder = CapsuleImageEncoder(
        encoder=image_cnn, **cfg.pcae_encoder
    )
    templates = TemplateGenerator(**cfg.pcae_template_generator)
    image_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)

    set_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    age_head = nn_ext.MLP_regression(**cfg.obj_age_regressor)
    # The object decoder is a thin wrapper around a single capsule layer.
    object_decoder = CapsuleObjectDecoder(
        CapsuleLayer(**cfg.ocae_decoder_capsule)
    )

    return SCAE(
        part_encoder=capsule_encoder,
        template_generator=templates,
        part_decoder=image_decoder,
        obj_encoder=set_encoder,
        obj_decoder=object_decoder,
        obj_age_regressor=age_head,
        **cfg.scae,
    )
def test_scae(self):
    """Smoke-test an ``SCAE`` built from ``model_params``.

    Builds the model component by component, then runs a forward pass,
    the loss and the accuracy computation on a random batch under
    ``torch.no_grad()``. Nothing is asserted: the test passes as long as
    no call raises.
    """
    cfg = Namespace(**factory.prepare_model_params(**model_params))

    image_cnn = CNNEncoder(**cfg.pcae_cnn_encoder)
    capsule_encoder = CapsuleImageEncoder(
        encoder=image_cnn, **cfg.pcae_encoder
    )
    templates = TemplateGenerator(**cfg.pcae_template_generator)
    image_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)
    set_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    capsule_layer = CapsuleLayer(**cfg.ocae_decoder_capsule)
    object_decoder = CapsuleObjectDecoder(capsule_layer)

    model = SCAE(
        part_encoder=capsule_encoder,
        template_generator=templates,
        part_decoder=image_decoder,
        obj_encoder=set_encoder,
        obj_decoder=object_decoder,
        **cfg.scae,
    )

    n_samples = 24
    with torch.no_grad():
        images = torch.rand(n_samples, *cfg.image_shape)
        labels = torch.randint(0, cfg.n_classes, (n_samples,))

        outputs = model(image=images)
        # The input batch itself serves as the reconstruction target.
        loss = model.loss(outputs, images, labels)
        accuracy = model.calculate_accuracy(outputs, labels)
def make_scae(model_params: dict):
    """Build an ``SCAE`` model from raw model parameters.

    Args:
        model_params: Keyword arguments forwarded to
            ``prepare_model_params``. Its result is wrapped in a
            ``Namespace`` and must provide ``pcae_cnn_encoder``,
            ``pcae_encoder``, ``pcae_template_generator``,
            ``pcae_decoder``, ``ocae_encoder_set_transformer``,
            ``ocae_decoder_capsule`` and ``scae``, each a mapping of
            constructor keyword arguments.

    Returns:
        The constructed ``SCAE`` instance.
    """
    prepared = prepare_model_params(**model_params)
    cfg = Namespace(**prepared)

    # Encoder side: the CNN feeds the capsule image encoder.
    backbone = CNNEncoder(**cfg.pcae_cnn_encoder)
    image_encoder = CapsuleImageEncoder(encoder=backbone, **cfg.pcae_encoder)

    template_gen = TemplateGenerator(**cfg.pcae_template_generator)
    template_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)

    object_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    # The object decoder wraps one capsule layer.
    object_decoder = CapsuleObjectDecoder(
        CapsuleLayer(**cfg.ocae_decoder_capsule)
    )

    model = SCAE(
        part_encoder=image_encoder,
        template_generator=template_gen,
        part_decoder=template_decoder,
        obj_encoder=object_encoder,
        obj_decoder=object_decoder,
        **cfg.scae,
    )
    return model
def test_set_transformer_with_isab(self):
    """Check the output shape of a ``SetTransformer`` using inducing points.

    A random ``(B, N, dim_in)`` input is passed with ``presence=None``
    through a model built with ``n_inducing_points=20``; the result must
    have shape ``(B, n_outputs, dim_out)``.
    """
    B = 32
    N = 40
    n_heads = 1
    dim_in = 16
    dim_hidden = 16
    dim_out = 256
    n_outputs = 10
    n_layers = 3

    x = torch.rand(B, N, dim_in)
    presence = None

    with torch.no_grad():
        st = SetTransformer(
            dim_in,
            dim_hidden,
            dim_out,
            n_outputs,
            n_layers,
            n_heads,
            layer_norm=True,
            n_inducing_points=20,
        )
        out = st(x, presence)

    # assertEqual reports both shapes on failure, whereas
    # assertTrue(a == b) only says "False is not true".
    self.assertEqual(tuple(out.shape), (B, n_outputs, dim_out))