Example #1
def decoder():
    # conv_block / convTranspose_block are project-local helpers that return
    # lists of modules, so `+=` extends the ModuleList
    conv_layer = nn.ModuleList()
    conv_layer += conv_block(64, 4 * 4 * 1024, 1, a_func='relu')
    conv_layer.append(Reshape((1024, 4, 4)))  # 4x4x1024
    conv_layer += convTranspose_block(1024, 512, 4, 2, 1)  # 8x8x512
    conv_layer += convTranspose_block(512, 256, 4, 2, 1)  # 16x16x256
    conv_layer += convTranspose_block(256, 128, 4, 2, 1)  # 32x32x128
    # last block: no batch norm and no activation, since Tanh is appended below
    conv_layer += convTranspose_block(128, 3, 4, 2, 1, bn=False, a_func='')  # 64x64x3
    conv_layer.append(nn.Tanh())
    return conv_layer
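
This snippet relies on project-local helpers that are not shown: Reshape, conv_block, and convTranspose_block. The sketches below are assumptions inferred from the call sites, not the project's actual implementations; conv_block would mirror convTranspose_block with nn.Conv2d in place of nn.ConvTranspose2d.

import torch.nn as nn

class Reshape(nn.Module):
    """Reshape the non-batch dimensions to a fixed target shape (assumed helper)."""

    def __init__(self, *shape):
        super().__init__()
        # accept both Reshape((1024, 4, 4)) and Reshape(1024, 4, 4)
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        self.shape = tuple(shape)

    def forward(self, x):
        # keep the batch dimension and reshape the rest; -1 entries are inferred
        return x.view(x.size(0), *self.shape)

def convTranspose_block(in_ch, out_ch, kernel, stride, padding, bn=True, a_func='relu'):
    """ConvTranspose2d + optional BatchNorm2d + optional ReLU, returned as a list
    so that `module_list += block` extends the nn.ModuleList (assumed helper)."""
    block = [nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding)]
    if bn:
        block.append(nn.BatchNorm2d(out_ch))
    if a_func == 'relu':
        block.append(nn.ReLU(inplace=True))
    return block

Note that nn.ModuleList has no forward() of its own, so the caller is expected to iterate over the returned layers or wrap them in nn.Sequential.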
Example #2
    def __init__(self, args):
        super(classifier, self).__init__()

        K = args.cls_hiddens
        L = np.prod(args.input_size)
        n_classes = args.n_classes
        self.args = args

        activation = nn.ReLU()
        self.layer = nn.Sequential(Reshape([-1]),
                                   GatedDense(L, K, activation=activation),
                                   nn.Dropout(p=0.2),
                                   GatedDense(K, n_classes, activation=None))

        # record gradient dimensions: the number of elements in each parameter
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
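
GatedDense is another helper that is not defined in this excerpt. Gated dense layers of this style (in the spirit of the gated activation units from PixelCNN and VampPrior-style VAE code) run a feature branch and a sigmoid gate branch in parallel; the sketch below is an assumption along those lines:

import torch
import torch.nn as nn

class GatedDense(nn.Module):
    """Fully connected layer with multiplicative sigmoid gating (assumed helper)."""

    def __init__(self, in_features, out_features, activation=None):
        super().__init__()
        self.h = nn.Linear(in_features, out_features)  # feature branch
        self.g = nn.Linear(in_features, out_features)  # gate branch
        self.activation = activation

    def forward(self, x):
        h = self.h(x)
        if self.activation is not None:
            h = self.activation(h)
        return h * torch.sigmoid(self.g(x))  # gate in [0, 1] scales the features

With this definition, the classifier flattens the input of size L, maps it to K gated hidden units, applies dropout, and produces n_classes logits.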
Example #3
    def __init__(self, imgsz, z_dim):
        """

        :param imgsz:
        :param z_dim:
        """
        super(Decoder, self).__init__()

        mapsz = 4
        ch_next = z_dim
        print('Decoder:', [z_dim], '=>', [2, ch_next, mapsz, mapsz], end='=>')

        # z: [b, z_dim] => [b, z_dim, 4, 4]
        layers = [
            # z_dim => z_dim * 4 * 4 => [z_dim, 4, 4] => [z_dim, 4, 4]
            nn.Linear(z_dim, z_dim * mapsz * mapsz),
            nn.BatchNorm1d(z_dim * mapsz * mapsz),
            nn.ReLU(inplace=True),
            Reshape(z_dim, mapsz, mapsz),
            ResBlk([3, 3], [z_dim, z_dim, z_dim])
        ]

        # scale the feature map up while keeping the channel count unchanged
        # [b, z_dim, 4, 4] => [b, z_dim, 8, 8] => [b, z_dim, 16, 16]
        for i in range(2):
            layers.extend([
                nn.Upsample(scale_factor=2),
                ResBlk([3, 3], [ch_next, ch_next, ch_next])
            ])
            mapsz = mapsz * 2

            # debug: run a forward pass and print the intermediate output shape
            tmp = torch.randn(2, z_dim)
            net = nn.Sequential(*layers)
            out = net(tmp)
            print(list(out.shape), end='=>')
            del net

        # scale the feature map up while scaling the channel count down
        # [b, z_dim, 16, 16] => [z_dim//2, 32, 32] => [z_dim//4, 64, 64] => [z_dim//8, 128, 128]
        # => [z_dim//16, 256, 256] => [z_dim//32, 512, 512]
        while mapsz < imgsz // 2:
            ch_cur = ch_next
            ch_next = ch_next // 2 if ch_next >= 32 else ch_next  # halve channels, minimum of 16
            layers.extend([
                # [2, 32, 32, 32] => [2, 32, 64, 64]
                nn.Upsample(scale_factor=2),
                # => [2, 16, 64, 64]
                ResBlk([1, 3, 3], [ch_cur, ch_next, ch_next, ch_next])
            ])
            mapsz = mapsz * 2

            # debug: run a forward pass and print the intermediate output shape
            tmp = torch.randn(2, z_dim)
            net = nn.Sequential(*layers)
            out = net(tmp)
            print(list(out.shape), end='=>')
            del net

        # [b, ch_next, 512, 512] => [b, 3, 1024, 1024]
        layers.extend([
            nn.Upsample(scale_factor=2),
            ResBlk([3, 3], [ch_next, ch_next, ch_next]),
            nn.Conv2d(ch_next, 3, kernel_size=5, stride=1, padding=2),
            # no final activation here; a sigmoid / tanh can be applied by the caller
        ])

        self.net = nn.Sequential(*layers)

        # debug: print the final output shape
        tmp = torch.randn(2, z_dim)
        out = self.net(tmp)
        print(list(out.shape))
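
ResBlk is defined elsewhere in this codebase. The call sites, ResBlk([3, 3], [ch, ch, ch]) and ResBlk([1, 3, 3], [ch_cur, ch_next, ch_next, ch_next]), suggest it chains len(kernels) convolutions along the channel list (one entry longer than the kernel list) and adds a skip connection; a minimal sketch under those assumptions:

import torch.nn as nn

class ResBlk(nn.Module):
    """Residual block: a chain of same-size convolutions plus a skip path (assumed helper)."""

    def __init__(self, kernels, chs):
        super().__init__()
        assert len(chs) == len(kernels) + 1
        layers = []
        for k, ch_in, ch_out in zip(kernels, chs[:-1], chs[1:]):
            # padding k // 2 keeps the spatial size unchanged for odd kernels
            layers += [
                nn.Conv2d(ch_in, ch_out, kernel_size=k, stride=1, padding=k // 2),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.body = nn.Sequential(*layers)
        # 1x1 projection on the skip path when the channel count changes
        self.skip = (nn.Identity() if chs[0] == chs[-1]
                     else nn.Conv2d(chs[0], chs[-1], kernel_size=1))

    def forward(self, x):
        return self.body(x) + self.skip(x)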
Example #4
    def create_encoder(self):
        """
        Helper function to create the elemental blocks for the encoder. Creates a gated convnet encoder.
        the encoder expects data as input of shape (batch_size, num_channels, width, height).
        """

        if self.input_type == 'binary':

            if self.args.gen_architecture == 'GatedConv':
                q_z_nn = nn.Sequential(
                    GatedConv2d(self.input_size[0], 32, 5, 1, 2),
                    nn.Dropout(self.args.dropout),
                    GatedConv2d(32, 32, 5, 2, 2),
                    nn.Dropout(self.args.dropout),
                    GatedConv2d(32, 64, 5, 1, 2),
                    nn.Dropout(self.args.dropout),
                    GatedConv2d(64, 64, 5, 2, 2),
                    nn.Dropout(self.args.dropout),
                    GatedConv2d(64, 64, 5, 1, 2),
                    nn.Dropout(self.args.dropout),
                    GatedConv2d(64, self.gen_hiddens, self.last_kernel_size, 1,
                                0),
                )
                assert self.args.gen_depth == 6, 'GatedConv encoder is hard-coded to 6 gated layers'

            elif self.args.gen_architecture == 'MLP':
                q_z_nn = [
                    Reshape([-1]),
                    nn.Linear(np.prod(self.args.input_size), self.gen_hiddens),
                    nn.ReLU(True),
                    nn.Dropout(self.args.dropout),
                ]
                for i in range(1, self.args.gen_depth):
                    q_z_nn += [
                        nn.Linear(self.args.gen_hiddens,
                                  self.args.gen_hiddens),
                        nn.ReLU(True),
                        nn.Dropout(self.args.dropout),
                    ]
                q_z_nn = nn.Sequential(*q_z_nn)
            else:
                raise ValueError('unknown gen_architecture: %s' %
                                 self.args.gen_architecture)

            q_z_mean = nn.Linear(self.gen_hiddens, self.z_size)
            q_z_var = nn.Sequential(
                nn.Linear(self.gen_hiddens, self.z_size),
                nn.Softplus(),
            )
            return q_z_nn, q_z_mean, q_z_var

        # TODO: add log_logistic loss for continuous inputs
        elif self.input_type in ['multinomial', 'continuous']:
            act = None

            q_z_nn = nn.Sequential(
                GatedConv2d(self.input_size[0], 32, 5, 1, 2, activation=act),
                GatedConv2d(32, 32, 5, 2, 2, activation=act),
                GatedConv2d(32, 64, 5, 1, 2, activation=act),
                GatedConv2d(64, 64, 5, 2, 2, activation=act),
                GatedConv2d(64, 64, 5, 1, 2, activation=act),
                GatedConv2d(64,
                            256,
                            self.last_kernel_size,
                            1,
                            0,
                            activation=act))
            q_z_mean = nn.Linear(256, self.z_size)
            q_z_var = nn.Sequential(nn.Linear(256, self.z_size), nn.Softplus(),
                                    nn.Hardtanh(min_val=0.01, max_val=7.))
            return q_z_nn, q_z_mean, q_z_var
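
GatedConv2d is again project-local. A common implementation, consistent with the activation keyword at the call sites, runs two parallel convolutions and multiplies the (optionally activated) feature branch by a sigmoid gate; this is a sketch under that assumption:

import torch
import torch.nn as nn

class GatedConv2d(nn.Module):
    """Convolution with multiplicative sigmoid gating (assumed helper)."""

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, activation=None):
        super().__init__()
        self.h = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)  # features
        self.g = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)  # gate
        self.activation = activation

    def forward(self, x):
        h = self.h(x)
        if self.activation is not None:
            h = self.activation(h)
        return h * torch.sigmoid(self.g(x))

Note how q_z_var ends in Softplus (plus a Hardtanh clamp in the continuous branch) so the predicted variance stays positive and bounded.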
Example #5
    def create_decoder(self):
        """
        Helper function to create the elemental blocks for the decoder. Creates a gated convnet decoder.
        """

        # TODO: check why num_classes was originally 256 (perhaps a 256-way
        # softmax over 8-bit pixel intensities); one value per channel is used here
        # num_classes = 256
        num_classes = 1

        if self.input_type == 'binary':
            if self.args.gen_architecture == 'GatedConv':
                p_x_nn = nn.Sequential(
                    Reshape([self.args.z_size, 1, 1]),
                    GatedConvTranspose2d(self.z_size, 64,
                                         self.last_kernel_size, 1, 0),
                    GatedConvTranspose2d(64, 64, 5, 1, 2),
                    GatedConvTranspose2d(64, 32, 5, 2, 2, 1),
                    GatedConvTranspose2d(32, 32, 5, 1, 2),
                    GatedConvTranspose2d(32, 32, 5, 2, 2, 1),
                    GatedConvTranspose2d(32, 32, 5, 1, 2))
                p_x_mean = nn.Sequential(
                    nn.Conv2d(32, self.input_size[0], 1, 1, 0), nn.Sigmoid())
            elif self.args.gen_architecture == 'MLP':
                p_x_nn = [
                    nn.Linear(self.z_size, self.gen_hiddens),
                    nn.ReLU(True),
                    nn.Dropout(self.args.dropout),
                ]
                for i in range(1, self.args.gen_depth):
                    p_x_nn += [
                        nn.Linear(self.args.gen_hiddens,
                                  self.args.gen_hiddens),
                        nn.ReLU(True),
                        nn.Dropout(self.args.dropout),
                    ]
                p_x_nn = nn.Sequential(*p_x_nn)

                p_x_mean = nn.Sequential(
                    nn.Linear(self.args.gen_hiddens,
                              np.prod(self.args.input_size)), nn.Sigmoid(),
                    Reshape(self.args.input_size))
            else:
                raise ValueError('unknown gen_architecture: %s' %
                                 self.args.gen_architecture)
            return p_x_nn, p_x_mean

        # TODO: add log_logistic loss for continuous inputs
        elif self.input_type in ['multinomial', 'continuous']:
            act = None
            p_x_nn = nn.Sequential(
                Reshape([self.args.z_size, 1, 1]),  # xuji added
                GatedConvTranspose2d(self.z_size,
                                     64,
                                     self.last_kernel_size,
                                     1,
                                     0,
                                     activation=act),
                GatedConvTranspose2d(64, 64, 5, 1, 2, activation=act),
                GatedConvTranspose2d(64, 32, 5, 2, 2, 1, activation=act),
                GatedConvTranspose2d(32, 32, 5, 1, 2, activation=act),
                GatedConvTranspose2d(32, 32, 5, 2, 2, 1, activation=act),
                GatedConvTranspose2d(32, 32, 5, 1, 2, activation=act))

            p_x_mean = nn.Sequential(
                nn.Conv2d(32, 256, 5, 1, 2),
                nn.Conv2d(256, self.input_size[0] * num_classes, 1, 1, 0),
                # output shape: (batch_size, num_channels * num_classes, height, width)
            )

            return p_x_nn, p_x_mean

        else:
            raise ValueError('invalid input type: %s' % self.input_type)
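
GatedConvTranspose2d presumably mirrors GatedConv2d with nn.ConvTranspose2d; the sixth positional argument at call sites such as GatedConvTranspose2d(64, 32, 5, 2, 2, 1) reads as output_padding. A sketch under those assumptions:

import torch
import torch.nn as nn

class GatedConvTranspose2d(nn.Module):
    """Transposed convolution with multiplicative sigmoid gating (assumed helper)."""

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding,
                 output_padding=0, activation=None):
        super().__init__()
        self.h = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride,
                                    padding, output_padding)
        self.g = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride,
                                    padding, output_padding)
        self.activation = activation

    def forward(self, x):
        h = self.h(x)
        if self.activation is not None:
            h = self.activation(h)
        return h * torch.sigmoid(self.g(x))

With kernel 5, stride 2, padding 2, and output_padding 1, each strided transposed convolution exactly doubles the spatial size ((n - 1) * 2 - 4 + 5 + 1 = 2n), matching the 2x upsampling steps in the decoder.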