Пример #1
0
 def __init__(self, in_channels, out_channels, indices):
     """Build four parallel SpiralConv branches over nested spiral prefixes.

     Each branch sees a different prefix of every spiral in ``indices``:
     the first two thirds, the first third, the full spiral, and only the
     first entry (column 0).

     Args:
         in_channels: input feature width shared by all branches.
         out_channels: width parameter for the branches; the two
             partial-spiral branches get ``out_channels // 4`` each, the
             full-spiral branch ``out_channels // 2`` and the single-entry
             branch the full ``out_channels``.
         indices: spiral index tensor (at least 2-D, given the ``[:, ...]``
             slicing); ``indices.size(1)`` is treated as the spiral length.
     """
     super(ParallelDeblock, self).__init__()
     # Prefix lengths along the spiral axis (dim 1).
     third = indices.size(1) // 3
     two_thirds = third * 2
     quarter = out_channels // 4

     # NOTE(review): conv_2d3 + conv_d3 + conv yield
     # 2 * (out_channels // 4) + out_channels // 2 features, which equals
     # out_channels only when out_channels is divisible by 4 -- confirm how
     # forward() combines the branches.
     # Registration order (conv_2d3, conv_d3, conv, conv1) is intentional:
     # it fixes the parameter / state_dict ordering.
     self.conv_2d3 = SpiralConv(in_channels, quarter, indices[:, :two_thirds])
     self.conv_d3 = SpiralConv(in_channels, quarter, indices[:, :third])
     self.conv = SpiralConv(in_channels, out_channels // 2, indices)
     self.conv1 = SpiralConv(in_channels, out_channels, indices[:, :1])
Пример #2
0
class SpiralDeblock(nn.Module):
    """Decoder block: ``Pool(x, up_transform)`` followed by SpiralConv + ELU.

    Args:
        in_channels: feature width consumed by the spiral convolution.
        out_channels: feature width produced by the block.
        indices: spiral indices forwarded unchanged to ``SpiralConv``.
    """

    def __init__(self, in_channels, out_channels, indices):
        super(SpiralDeblock, self).__init__()
        self.conv = SpiralConv(in_channels, out_channels, indices)
        # Re-initialise right after construction.
        self.reset_parameters()

    def reset_parameters(self):
        """Delegate re-initialisation to the wrapped spiral convolution."""
        self.conv.reset_parameters()

    def forward(self, x, up_transform):
        """Return ``elu(conv(Pool(x, up_transform)))``.

        ``up_transform`` is passed straight to ``Pool``; presumably it is the
        upsampling operator for this level (not visible here -- confirm).
        """
        pooled = Pool(x, up_transform)
        return F.elu(self.conv(pooled))
Пример #3
0
    def __init__(self, in_channels, out_channels, latent_channels,
                 spiral_indices, num_vert, up_transform, lam):
        """Build the decoder-only ``AD`` model.

        Args:
            in_channels: feature width produced by the final SpiralConv
                (which maps ``out_channels[0]`` to ``in_channels``).
            out_channels: sequence of per-level channel widths; the decoder
                walks it from the last entry back to the first.
            latent_channels: input size of the first ``nn.Linear``.
            spiral_indices: per-level spiral indices, indexed in step with
                ``out_channels``.
            num_vert: vertex count used to size the first Linear's output
                (``num_vert * out_channels[-1]``).
            up_transform: stored on the instance; not used in this method.
            lam: stored as ``self.lam``; its role is not visible here.
        """
        super(AD, self).__init__()
        # Configuration kept on the instance for other methods.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_channels = latent_channels
        self.spiral_indices = spiral_indices
        self.up_transform = up_transform
        self.num_vert = num_vert
        self.lam = lam
        self.type = 'AD'
        # Both start as None; presumably filled with latent codes later.
        self.z_train = None
        self.z_test = None

        # decoder: Linear from the latent code, then one SpiralDeblock per
        # entry of out_channels (last to first), then a SpiralConv back to
        # in_channels.
        decoder = nn.ModuleList()
        decoder.append(
            nn.Linear(latent_channels, num_vert * out_channels[-1]))
        for step in range(len(out_channels)):
            target = -step - 1
            # The first block keeps the last width; every later block takes
            # the width produced by the block before it.
            source = target if step == 0 else -step
            decoder.append(
                SpiralDeblock(out_channels[source], out_channels[target],
                              spiral_indices[target]))
        decoder.append(
            SpiralConv(out_channels[0], in_channels, spiral_indices[0]))
        self.de_layers = decoder

        self.reset_parameters()
Пример #4
0
class Net(torch.nn.Module):
    """Spiral-convolution network producing class log-probabilities.

    Pipeline: a Linear lift of the input features to 16 channels, three
    ``SpiralConv`` layers (16 -> 32 -> 64 -> 128), then a two-layer Linear
    head (128 -> 256 -> ``num_classes``). ELU follows every layer except the
    last; dropout is applied before the final Linear.

    Args:
        in_channels: feature width of the input on its last axis.
        num_classes: number of output classes.
        indices: spiral indices passed to every ``SpiralConv``.
    """

    def __init__(self, in_channels, num_classes, indices):
        super(Net, self).__init__()

        self.fc0 = nn.Linear(in_channels, 16)
        self.conv1 = SpiralConv(16, 32, indices)
        self.conv2 = SpiralConv(32, 64, indices)
        self.conv3 = SpiralConv(64, 128, indices)
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every layer.

        SpiralConv layers use their own ``reset_parameters``; Linear layers
        get Xavier-uniform weights (gain 1) and zero biases.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            conv.reset_parameters()
        # Weights are drawn in the order fc0, fc1, fc2; zero-filling the bias
        # consumes no randomness, so interleaving it keeps the RNG sequence.
        for fc in (self.fc0, self.fc1, self.fc2):
            nn.init.xavier_uniform_(fc.weight, gain=1)
            nn.init.constant_(fc.bias, 0)

    def forward(self, x):
        """Return class log-probabilities for input features ``x``.

        Features are expected on the last axis of ``x`` (assumed
        (num_vertices, in_channels), or batched if SpiralConv supports it --
        confirm). The log-softmax is taken over the last axis, which is where
        ``fc2`` places the ``num_classes`` scores.
        """
        x = F.elu(self.fc0(x))
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # nn.Linear maps the last axis, so class scores are always on dim -1.
        # For 2-D input this equals the old dim=1; for batched 3-D input
        # dim=1 would wrongly normalise across vertices.
        return F.log_softmax(x, dim=-1)
Пример #5
0
    def __init__(self, in_channels, out_channels, latent_channels,
                 spiral_indices, down_transform, up_transform, lam):
        """Build the spiral-convolution VAE.

        Args:
            in_channels: input width of the first SpiralEnblock and output
                width of the final SpiralConv.
            out_channels: per-level channel widths; the encoder walks it
                first-to-last and the decoder last-to-first.
            latent_channels: output size of ``en_mu`` / ``en_logvar`` and
                input size of the first decoder Linear.
            spiral_indices: per-level spiral indices, indexed in step with
                ``out_channels``.
            down_transform: sequence whose last element's ``size(0)`` is used
                as ``self.num_vert``.
            up_transform: stored on the instance; not used in this method.
            lam: stored as ``self.lam``; its role is not visible here.
        """
        super(VAE, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_channels = latent_channels
        self.spiral_indices = spiral_indices
        self.down_transform = down_transform
        self.up_transform = up_transform
        self.num_vert = down_transform[-1].size(0)
        self.lam = lam
        self.type = 'VAE'

        levels = len(out_channels)
        # Size of the features entering en_mu / en_logvar and leaving the
        # first decoder Linear.
        flat_width = self.num_vert * out_channels[-1]

        # encoder: one SpiralEnblock per level, each consuming the previous
        # level's width (in_channels for the first level).
        encoder = nn.ModuleList()
        for level in range(levels):
            width_in = in_channels if level == 0 else out_channels[level - 1]
            encoder.append(
                SpiralEnblock(width_in, out_channels[level],
                              spiral_indices[level]))
        self.en_layers = encoder

        # Two parallel Linear heads (presumably mean and log-variance of the
        # latent distribution -- confirm against forward()).
        self.en_mu = nn.Linear(flat_width, latent_channels)
        self.en_logvar = nn.Linear(flat_width, latent_channels)

        # decoder: Linear from the latent code, one SpiralDeblock per level
        # (last to first), then a SpiralConv back to in_channels.
        decoder = nn.ModuleList()
        decoder.append(nn.Linear(latent_channels, flat_width))
        for step in range(levels):
            target = -step - 1
            # First block keeps the last width; later blocks take the width
            # produced by the block before them.
            source = target if step == 0 else -step
            decoder.append(
                SpiralDeblock(out_channels[source], out_channels[target],
                              spiral_indices[target]))
        decoder.append(
            SpiralConv(out_channels[0], in_channels, spiral_indices[0]))
        self.de_layers = decoder

        self.reset_parameters()
Пример #6
0
    def __init__(self, in_channels, num_classes, indices):
        """Create the Linear lift, three SpiralConv layers and a Linear head.

        Args:
            in_channels: input feature width for ``fc0``.
            num_classes: output width of ``fc2``.
            indices: spiral indices shared by all three SpiralConv layers.
        """
        super(Net, self).__init__()

        # Channel widths: lift -> three spiral convs -> hidden head width.
        lift, w1, w2, w3, head = 16, 32, 64, 128, 256

        # Construction / registration order (fc0, conv1..conv3, fc1, fc2) is
        # kept, so parameter ordering and RNG consumption are unchanged.
        self.fc0 = nn.Linear(in_channels, lift)
        self.conv1 = SpiralConv(lift, w1, indices)
        self.conv2 = SpiralConv(w1, w2, indices)
        self.conv3 = SpiralConv(w2, w3, indices)
        self.fc1 = nn.Linear(w3, head)
        self.fc2 = nn.Linear(head, num_classes)

        self.reset_parameters()
Пример #7
0
 def __init__(self, in_channels, out_channels, indices):
     """Wrap one SpiralConv(in_channels -> out_channels) over ``indices``,
     then re-initialise it through ``self.reset_parameters()``."""
     super(SpiralDeblock, self).__init__()
     spiral_conv = SpiralConv(in_channels, out_channels, indices)
     self.conv = spiral_conv
     # Explicit re-initialisation immediately after construction.
     self.reset_parameters()