def make_scae(config):
    """Assemble an ``SCAE`` model from a configuration object.

    Args:
        config: Either a ``dict`` or an object with attribute access. A
            ``dict`` is wrapped in a ``Namespace`` first. It must provide
            the keyword-argument mappings ``pcae_cnn_encoder``,
            ``pcae_encoder``, ``pcae_template_generator``,
            ``pcae_decoder``, ``ocae_encoder_set_transformer``,
            ``obj_age_regressor``, ``ocae_decoder_capsule`` and ``scae``.

    Returns:
        The constructed ``SCAE`` instance.
    """
    cfg = Namespace(**config) if isinstance(config, dict) else config

    # NOTE(review): sub-modules are built in the same order as before, in
    # case their initialisation draws from a shared RNG -- confirm if
    # reordering is ever needed.
    image_encoder = CapsuleImageEncoder(
        encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
        **cfg.pcae_encoder,
    )
    templates = TemplateGenerator(**cfg.pcae_template_generator)
    image_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)
    set_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    age_regressor = nn_ext.MLP_regression(**cfg.obj_age_regressor)
    object_decoder = CapsuleObjectDecoder(
        CapsuleLayer(**cfg.ocae_decoder_capsule)
    )

    return SCAE(
        part_encoder=image_encoder,
        template_generator=templates,
        part_decoder=image_decoder,
        obj_encoder=set_encoder,
        obj_decoder=object_decoder,
        obj_age_regressor=age_regressor,
        **cfg.scae,
    )
def make_scae(model_params: dict):
    """Build an ``SCAE`` model from raw model parameters.

    Args:
        model_params: Keyword arguments forwarded to
            ``prepare_model_params``. The mapping it returns is wrapped in
            a ``Namespace`` and must provide the keyword-argument mappings
            ``pcae_cnn_encoder``, ``pcae_encoder``,
            ``pcae_template_generator``, ``pcae_decoder``,
            ``ocae_encoder_set_transformer``, ``ocae_decoder_capsule`` and
            ``scae``.

    Returns:
        The constructed ``SCAE`` instance.
    """
    prepared = prepare_model_params(**model_params)
    cfg = Namespace(**prepared)

    # Dict literals evaluate their values top to bottom, so the sub-modules
    # are still created in the original order.
    components = {
        "part_encoder": CapsuleImageEncoder(
            encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
            **cfg.pcae_encoder,
        ),
        "template_generator": TemplateGenerator(
            **cfg.pcae_template_generator
        ),
        "part_decoder": TemplateBasedImageDecoder(**cfg.pcae_decoder),
        "obj_encoder": SetTransformer(**cfg.ocae_encoder_set_transformer),
        "obj_decoder": CapsuleObjectDecoder(
            CapsuleLayer(**cfg.ocae_decoder_capsule)
        ),
    }
    return SCAE(**components, **cfg.scae)
def test_scae(self):
    """Smoke-test ``SCAE``: forward pass, loss and accuracy on random data.

    The model is assembled from the module-level ``model_params`` via
    ``factory.prepare_model_params``. The test makes no assertions; it only
    verifies that the three calls complete without raising.
    """
    cfg = Namespace(**factory.prepare_model_params(**model_params))

    # Assemble the model piece by piece (same construction order as the
    # individual components are listed in the configuration).
    image_encoder = CapsuleImageEncoder(
        encoder=CNNEncoder(**cfg.pcae_cnn_encoder),
        **cfg.pcae_encoder,
    )
    templates = TemplateGenerator(**cfg.pcae_template_generator)
    image_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)
    set_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    object_decoder = CapsuleObjectDecoder(
        CapsuleLayer(**cfg.ocae_decoder_capsule)
    )
    model = SCAE(
        part_encoder=image_encoder,
        template_generator=templates,
        part_decoder=image_decoder,
        obj_encoder=set_encoder,
        obj_decoder=object_decoder,
        **cfg.scae,
    )

    with torch.no_grad():
        n_samples = 24
        images = torch.rand(n_samples, *cfg.image_shape)
        labels = torch.randint(0, cfg.n_classes, (n_samples,))

        outputs = model(image=images)
        # The input image doubles as the reconstruction target.
        model.loss(outputs, images, labels)
        model.calculate_accuracy(outputs, labels)
def test_capsule_likelihood(self):
    """Check the shape of every field returned by ``CapsuleObjectDecoder``.

    A ``CapsuleLayer`` built from a fixed configuration is wrapped in a
    ``CapsuleObjectDecoder`` and called on random inputs of shapes
    ``(B, O, D)``, ``(B, V, P)`` and ``(B, V)``. Each result field is then
    compared against its expected shape.
    """
    capsule_layer_config = AttrDict(
        n_caps=32,
        dim_feature=256,
        n_votes=40,
        dim_caps=32,
        hidden_sizes=(128,),
        learn_vote_scale=True,
        allow_deformations=True,
        noise_type='uniform',
        noise_scale=4.,
        similarity_transform=False,
        caps_dropout_rate=0.0,
    )
    capsule_layer = CapsuleLayer(**capsule_layer_config)
    capsule_obj_decoder = CapsuleObjectDecoder(capsule_layer)

    B = 24
    O = capsule_layer_config.n_caps
    D = capsule_layer_config.dim_feature
    V = capsule_layer_config.n_votes
    P = 6

    h = torch.rand(B, O, D)
    x = torch.rand(B, V, P)
    presence = torch.rand(B, V)

    with torch.no_grad():
        result = capsule_obj_decoder(h, x, presence)

    expected_shapes = {
        'vote': (B, O, V, P),
        'scale': (B, O, V),
        'vote_presence': (B, O, V),
        'presence_logit_per_caps': (B, O, 1),
        'presence_logit_per_vote': (B, O, V),
        'cpr_dynamic_reg_loss': (),
        'log_prob': (),
        'winner': (B, V, P),
        'winner_presence': (B, V),
        'soft_winner': (B, V, P),
        'soft_winner_presence': (B, V),
        'posterior_mixing_prob': (B, O, V),
        'mixing_logit': (B, O + 1, V),
        'mixing_log_prob': (B, O + 1, V),
        'caps_presence': (B, O),
    }
    # assertEqual reports both shapes on failure, and subTest names the
    # offending field, unlike a bare assertTrue(a == b).
    for field, shape in expected_shapes.items():
        with self.subTest(field=field):
            self.assertEqual(getattr(result, field).shape, shape)
def test_capsule_layer(self):
    """Check the shape of every field returned by ``CapsuleLayer``.

    The layer is built from a fixed configuration and called on a random
    ``(B, O, F)`` feature tensor with no parent transform or presence.
    Each result field is then compared against its expected shape.
    """
    capsule_layer_config = AttrDict(
        n_caps=32,
        dim_feature=256,
        n_votes=40,
        dim_caps=32,
        hidden_sizes=(128,),
        learn_vote_scale=True,
        allow_deformations=True,
        noise_type='uniform',
        noise_scale=4.,
        similarity_transform=False,
        caps_dropout_rate=0.0,
    )

    B = 24
    O = capsule_layer_config.n_caps
    F = capsule_layer_config.dim_feature
    V = capsule_layer_config.n_votes

    # The feature tensor is drawn before the layer is created, as in the
    # original test, so RNG consumption order is unchanged.
    feature = torch.rand(B, O, F)
    capsule_layer = CapsuleLayer(**capsule_layer_config)

    with torch.no_grad():
        result = capsule_layer(
            feature,
            parent_presence=None,
            parent_transform=None,
        )

    expected_shapes = {
        'vote': (B, O, V, 3, 3),
        'scale': (B, O, V),
        'vote_presence': (B, O, V),
        'presence_logit_per_caps': (B, O, 1),
        'presence_logit_per_vote': (B, O, V),
        'cpr_dynamic_reg_loss': (),
    }
    # assertEqual reports both shapes on failure, and subTest names the
    # offending field, unlike a bare assertTrue(a == b).
    for field, shape in expected_shapes.items():
        with self.subTest(field=field):
            self.assertEqual(getattr(result, field).shape, shape)