def create_capsnet(input_shape, n_class, out_dim, num_routing):
    """Build the three CapsNet Keras models sharing one capsule encoder and decoder.

    Args:
        input_shape: shape of one input sample; fed to ``layers.Input`` and
            reused as the decoder's reshape target.
        n_class: number of output capsules created by ``CapsuleLayer``; also
            the width of the label input used for masking.
        out_dim: vector size of each output capsule (``dim_vector``).
        num_routing: routing iterations passed to ``CapsuleLayer``.

    Returns:
        ``(train_model, eval_model, manipulate_model)`` where

        * ``train_model`` maps ``[image, label]`` to
          ``[capsule lengths, reconstruction]``, decoding the capsules masked
          by the given label;
        * ``eval_model`` maps ``image`` to ``[capsule lengths, reconstruction]``,
          decoding capsules masked by ``Mask()`` without a label (per the
          original comment: the capsule with maximal length);
        * ``manipulate_model`` maps ``[image, label, noise]`` to the
          reconstruction of the label-masked capsules after adding ``noise``.
    """
    # ---- Capsule encoder ------------------------------------------------
    image_in = layers.Input(shape=input_shape)
    stem = layers.Conv2D(
        filters=64,
        kernel_size=9,
        strides=1,
        padding='valid',
        activation='relu',
        name='conv1',
    )(image_in)
    primary = PrimaryCaps(
        layer_input=stem,
        name='primary_caps',
        dim_capsule=3,
        channels=2,
        kernel_size=9,
        strides=2,
    )
    class_caps = CapsuleLayer(
        num_capsule=n_class, dim_vector=out_dim, num_routing=num_routing
    )(primary)
    caps_lengths = Length(name='capsnet')(class_caps)

    # ---- Masking --------------------------------------------------------
    label_in = layers.Input(shape=(n_class,))
    # Training: the true label selects which capsule reaches the decoder.
    label_masked = Mask()([class_caps, label_in])
    # Prediction: no label; Mask() picks the capsule with maximal length.
    length_masked = Mask()(class_caps)

    # ---- Decoder shared by every model ---------------------------------
    # Layers are added one at a time (create, then add) to keep the exact
    # construction/build order, which drives Keras layer auto-naming.
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=out_dim * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='decoder_output'))

    # ---- Training / evaluation models ----------------------------------
    train_model = models.Model(
        [image_in, label_in], [caps_lengths, decoder(label_masked)]
    )
    eval_model = models.Model(image_in, [caps_lengths, decoder(length_masked)])

    # ---- Manipulation model: perturb capsules with additive noise -------
    noise_in = layers.Input(shape=(n_class, out_dim))
    perturbed = layers.Add()([class_caps, noise_in])
    perturbed_masked = Mask()([perturbed, label_in])
    manipulate_model = models.Model(
        [image_in, label_in, noise_in], decoder(perturbed_masked)
    )

    return train_model, eval_model, manipulate_model
def __init__(self):
    """Create the CapsuleNet layers: conv stem, two capsule layers, decoder.

    The stem convolves single-channel input (``in_channels=1``) into 256
    feature maps. A primary capsule layer follows, then one digit capsule per
    class (``config.NUM_CLASSES``). A fully connected decoder maps the
    flattened ``16 * NUM_CLASSES`` capsule outputs to 784 sigmoid values.

    NOTE(review): ``32 * 6 * 6`` route nodes and the 784-wide decoder output
    look tied to 28x28 inputs — confirm against the data pipeline.
    """
    super(CapsuleNet, self).__init__()

    # Stem: 9x9 convolution, stride 1, 1 -> 256 channels.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)

    # Primary capsules; num_route_nodes=-1 is interpreted by CapsuleLayer.
    self.primary_capsules = CapsuleLayer(
        num_capsules=8,
        num_route_nodes=-1,
        in_channels=256,
        out_channels=32,
        kernel_size=9,
        stride=2,
    )

    n_classes = config.NUM_CLASSES
    caps_width = 16  # digit-capsule out_channels; also the decoder's per-class input width

    self.digit_capsules = CapsuleLayer(
        num_capsules=n_classes,
        num_route_nodes=32 * 6 * 6,
        in_channels=8,
        out_channels=caps_width,
    )

    # Decoder: flattened capsules -> 512 -> 1024 -> 784, sigmoid output.
    self.decoder = nn.Sequential(
        nn.Linear(caps_width * n_classes, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 784),
        nn.Sigmoid(),
    )
def __init__(self, S_set):
    """Construct every FSANet sub-module.

    Args:
        S_set: indexable hyper-parameters read positionally as
            ``(num_capsule, dim_capsule, routings, num_primcaps, m_dim)``;
            the original comments give example values 3, 16, 2, 8*8*3, 5.

    Two parallel convolutional streams are built. The "x" stream uses
    BatchNorm + ReLU with average pooling. The "s" stream uses
    BatchNorm + Tanh with max pooling. These are followed by 1x1 projections
    to 64 channels, single-channel sigmoid maps, sigmoid-activated linear
    maps, small Tanh/ReLU heads and a domain classifier.

    Sub-modules are created in a fixed order. PyTorch layers draw their
    initial weights from the global RNG at construction time, so reordering
    would change initialisation under a fixed seed.
    """
    super(FSANet, self).__init__()
    self.ssr_layer = SSRLayer()

    # Hyper-parameters (example values from the original comments).
    self.num_capsule = S_set[0]   # e.g. 3
    self.dim_capsule = S_set[1]   # e.g. 16
    self.routings = S_set[2]      # e.g. 2
    self.num_primcaps = S_set[3]  # e.g. 8*8*3
    self.m_dim = S_set[4]         # e.g. 5

    self.capsule_layer = CapsuleLayer(
        in_units=self.num_primcaps,
        in_channels=64,
        num_units=self.num_capsule,
        unit_size=self.dim_capsule,
    )

    # Local factories: each call returns fresh module instances, so the
    # resulting Sequential layouts (and state_dict keys) match a hand-written list.
    def relu():
        return nn.ReLU(inplace=True)

    def conv_bn_act(c_in, c_out, act):
        """3x3 conv (stride 1, padding 1) -> BatchNorm2d -> activation."""
        return [nn.Conv2d(c_in, c_out, 3, 1, 1), nn.BatchNorm2d(c_out), act()]

    def conv1x1(c_in, c_out, act):
        """1x1 conv (stride 1, no padding) followed by an activation."""
        return nn.Sequential(nn.Conv2d(c_in, c_out, 1, 1, 0), act())

    def linear_act(n_in, n_out, act):
        """Linear layer followed by an activation."""
        return nn.Sequential(nn.Linear(n_in, n_out), act())

    # "x" stream: BatchNorm + ReLU, average pooling between stages.
    self.x_layer1 = nn.Sequential(*conv_bn_act(3, 16, relu), nn.AvgPool2d(2))
    self.x_layer2 = nn.Sequential(
        *conv_bn_act(16, 32, relu), *conv_bn_act(32, 32, relu), nn.AvgPool2d(2)
    )
    self.x_layer3 = nn.Sequential(
        *conv_bn_act(32, 64, relu), *conv_bn_act(64, 64, relu), nn.AvgPool2d(2)
    )
    self.x_layer4 = nn.Sequential(
        *conv_bn_act(64, 128, relu), *conv_bn_act(128, 128, relu)
    )

    # "s" stream: same channel widths, BatchNorm + Tanh, max pooling.
    self.s_layer1 = nn.Sequential(*conv_bn_act(3, 16, nn.Tanh), nn.MaxPool2d(2))
    self.s_layer2 = nn.Sequential(
        *conv_bn_act(16, 32, nn.Tanh), *conv_bn_act(32, 32, nn.Tanh), nn.MaxPool2d(2)
    )
    self.s_layer3 = nn.Sequential(
        *conv_bn_act(32, 64, nn.Tanh), *conv_bn_act(64, 64, nn.Tanh), nn.MaxPool2d(2)
    )
    self.s_layer4 = nn.Sequential(
        *conv_bn_act(64, 128, nn.Tanh), *conv_bn_act(128, 128, nn.Tanh)
    )

    # 1x1 projections of stages 4/3/2 of each stream down to 64 channels.
    self.x_layer4_ = conv1x1(128, 64, relu)
    self.x_layer3_ = conv1x1(64, 64, relu)
    self.x_layer2_ = conv1x1(32, 64, relu)
    self.s_layer4_ = conv1x1(128, 64, nn.Tanh)
    self.s_layer3_ = conv1x1(64, 64, nn.Tanh)
    self.s_layer2_ = conv1x1(32, 64, nn.Tanh)

    # (sic) attribute name kept as-is; presumably referenced by forward().
    self.agvpool = nn.AvgPool2d(2)

    # Single-channel sigmoid maps computed from 64-channel features.
    self.feat_preS1 = conv1x1(64, 1, nn.Sigmoid)
    self.feat_preS2 = conv1x1(64, 1, nn.Sigmoid)
    self.feat_preS3 = conv1x1(64, 1, nn.Sigmoid)

    # NOTE(review): presumably a flattened 8x8 map per stage — confirm in forward().
    grid = 8 * 8
    self.sr_matrix1 = linear_act(grid, self.m_dim * grid * 3, nn.Sigmoid)
    self.sr_matrix2 = linear_act(grid, self.m_dim * grid * 3, nn.Sigmoid)
    self.sr_matrix3 = linear_act(grid, self.m_dim * grid * 3, nn.Sigmoid)
    self.SL_matrix = linear_act(
        grid * 3, int(self.num_primcaps / 3) * self.m_dim, nn.Sigmoid
    )

    # Small per-stage heads.
    self.delta_s1 = linear_act(4, 3, nn.Tanh)
    self.delta_s2 = linear_act(4, 3, nn.Tanh)
    self.delta_s3 = linear_act(4, 3, nn.Tanh)
    self.local_s1 = linear_act(4, 3, nn.Tanh)
    self.local_s2 = linear_act(4, 3, nn.Tanh)
    self.local_s3 = linear_act(4, 3, nn.Tanh)
    self.pred_s1 = linear_act(8, 9, relu)
    self.pred_s2 = linear_act(8, 9, relu)
    self.pred_s3 = linear_act(8, 9, relu)

    # Two-way classifier emitting log-probabilities (LogSoftmax over dim 1).
    self.domain_classifier = nn.Sequential(
        nn.Linear(21 * 64, 100),
        nn.BatchNorm1d(100),
        nn.ReLU(True),
        nn.Linear(100, 2),
        nn.LogSoftmax(dim=1),
    )
import torch
from attention import SetTransformer
from capsule import CapsuleLayer

# Dummy input, shape (32, 11, 2) as in the original setup.
# NOTE(review): presumably (batch, set size, features); SetTransformer(2)
# below suggests 2 input features — confirm.
INPUT = torch.ones((32, 11, 2))

# Smoke test: encode the dummy set, then run the capsule decoder on the result.
# Modules are invoked via __call__ rather than .forward() so nn.Module hooks
# run as intended (assumes both classes are torch.nn.Module subclasses).
st = SetTransformer(2)
encoded = st(INPUT)
print(encoded.shape)

decoder = CapsuleLayer(
    input_dims=32,
    n_caps=3,
    n_caps_dims=2,
    n_votes=4,
    n_caps_params=32,
    n_hiddens=128,
    learn_vote_scale=True,
    deformations=True,
    noise_type='uniform',
    noise_scale=4.,
    similarity_transform=False,
)
decoder(encoded)