Ejemplo n.º 1
0
    def define_module(self):
        """Create all sub-networks of this generator and attach them to ``self``.

        Reads ``motion_dim``, ``content_dim``, ``gf_dim``, ``filter_size`` and
        ``r_image_size`` from the instance and builds, in this order:
        ``ca_net``, ``fc``, ``upsample1``..``upsample4``, ``img``,
        ``filter_net``, ``image_net``, ``dfn_layer`` and ``downsamples``.

        NOTE(review): the creation order is kept deliberately. Assuming this
        class is an ``nn.Module``, it determines sub-module registration order
        and the order in which random initial weights are drawn.
        """
        # Function-scope import of the project's dynamic filter layer.
        from layers import DynamicFilterLayer

        motion_dim = self.motion_dim
        content_dim = self.content_dim
        latent_dim = motion_dim + content_dim
        ngf = self.gf_dim

        # Original note: TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        # Channel widths used by the decoder stages below.
        half = ngf // 2
        quarter = ngf // 4
        eighth = ngf // 8
        sixteenth = ngf // 16

        # The projection emits ngf * 8 features, i.e. half of ngf x 4 x 4.
        # Presumably the other half comes from ``downsamples`` (which ends
        # with ngf // 2 channels) -- TODO confirm against forward().
        fc_width = ngf * 4 * 4 // 2
        fc_layers = [
            nn.Linear(latent_dim, fc_width, bias=False),
            nn.BatchNorm1d(fc_width),
            nn.ReLU(True),
        ]
        self.fc = nn.Sequential(*fc_layers)

        # Decoder stages; shapes in the comments are the original author's.
        self.upsample1 = upBlock(ngf, half)           # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample2 = upBlock(half, quarter)       # -> ngf/4 x 16 x 16
        self.upsample3 = upBlock(quarter, eighth)     # -> ngf/8 x 32 x 32
        self.upsample4 = upBlock(eighth, sixteenth)   # -> ngf/16 x 64 x 64

        # Final 3-channel output squashed by tanh (-> 3 x 64 x 64).
        to_rgb = conv3x3(sixteenth, 3)
        self.img = nn.Sequential(to_rgb, nn.Tanh())

        # Maps the content code to filter_size**2 values (one square kernel).
        # A trailing softmax exists only as disabled code in the original.
        ksize = self.filter_size
        kernel_area = ksize ** 2
        filter_layers = [
            nn.Linear(content_dim, kernel_area, bias=False),
            nn.BatchNorm1d(kernel_area),
        ]
        self.filter_net = nn.Sequential(*filter_layers)

        # Maps the motion code to r_image_size**2 values.
        image_area = self.r_image_size ** 2
        image_layers = [
            nn.Linear(motion_dim, image_area, bias=False),
            nn.BatchNorm1d(image_area),
        ]
        self.image_net = nn.Sequential(*image_layers)

        # Dynamic filter layer with half-kernel padding on both axes.
        kernel_shape = (ksize, ksize, 1)
        padding = (ksize // 2, ksize // 2)
        self.dfn_layer = DynamicFilterLayer(
            kernel_shape, pad=padding, grouping=False)

        # Two stride-2 convolutions: 1 -> ngf -> ngf // 2 channels.
        # The original carries a disabled spectral_norm wrapper for both convs.
        self.downsamples = nn.Sequential(
            nn.Conv2d(1, ngf, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ngf, half, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(half),
            nn.LeakyReLU(0.2, inplace=True),
        )
    def define_module(self):
        """Create all sub-networks of this generator and attach them to ``self``.

        Reads ``motion_dim``, ``content_dim``, ``image_size``, ``gf_dim``,
        ``filter_size``, ``filter_num``, ``out_num`` and ``use_segment`` from
        the instance (plus ``gf_dim_seg`` when ``use_segment`` is truthy).

        Built in this order: ``ca_net``, ``filter_net``, ``image_net``,
        ``fc``, ``upsample1``..``upsample4``, ``img``, the optional
        segmentation branch, ``m_net``, ``c_net`` and ``dfn_layer``.

        NOTE(review): the creation order is kept deliberately. Assuming this
        class is an ``nn.Module``, it determines sub-module registration order
        and the order in which random initial weights are drawn.
        """
        # Function-scope import; the 1D variant is aliased to the generic name.
        from layers import DynamicFilterLayer1D as DynamicFilterLayer

        motion_dim = self.motion_dim
        content_dim = self.content_dim
        # Original note gives example sizes: 365 + 124 + 124 = 613.
        latent_dim = motion_dim + content_dim + self.image_size
        # Original note: 128 * 8 = 1024.
        ngf = self.gf_dim

        self.ca_net = CA_NET()

        # Content code -> filter_size * filter_num * out_num filter values.
        filter_width = self.filter_size * self.filter_num * self.out_num
        filter_layers = [
            nn.Linear(content_dim, filter_width),
            nn.BatchNorm1d(filter_width),
        ]
        self.filter_net = nn.Sequential(*filter_layers)

        # Motion code -> image_size * filter_num values, squashed by tanh.
        image_width = self.image_size * self.filter_num
        image_layers = [
            nn.Linear(motion_dim, image_width),
            nn.BatchNorm1d(image_width),
            nn.Tanh(),
        ]
        self.image_net = nn.Sequential(*image_layers)

        # Image branch: projection to ngf * 4 * 4 features, then four
        # upBlock stages. Shape comments are the original author's.
        fc_width = ngf * 4 * 4
        fc_layers = [
            nn.Linear(latent_dim, fc_width, bias=False),
            nn.BatchNorm1d(fc_width),
            nn.ReLU(True),
        ]
        self.fc = nn.Sequential(*fc_layers)

        half = ngf // 2
        quarter = ngf // 4
        eighth = ngf // 8
        sixteenth = ngf // 16
        self.upsample1 = upBlock(ngf, half)
        self.upsample2 = upBlock(half, quarter)       # -> ngf/4 x 16 x 16
        self.upsample3 = upBlock(quarter, eighth)     # -> ngf/8 x 32 x 32
        self.upsample4 = upBlock(eighth, sixteenth)   # -> ngf/16 x 64 x 64
        # Final 3-channel output squashed by tanh (-> 3 x 64 x 64).
        self.img = nn.Sequential(conv3x3(sixteenth, 3), nn.Tanh())

        if self.use_segment:
            ngf_seg = self.gf_dim_seg
            seg_half = ngf_seg // 2

            # 3x3 convs from segmentation widths to image-branch widths.
            # Deeper seg_c2..seg_c4 variants exist only as disabled code
            # in the original.
            self.seg_c = conv3x3(ngf_seg, ngf)
            self.seg_c1 = conv3x3(seg_half, half)

            # Segmentation branch (original note: used to generate seg and
            # img in variants v4, v5 and v6). Same layout as the image
            # branch, but with a single output channel.
            seg_fc_width = ngf_seg * 4 * 4
            seg_fc_layers = [
                nn.Linear(latent_dim, seg_fc_width, bias=False),
                nn.BatchNorm1d(seg_fc_width),
                nn.ReLU(True),
            ]
            self.fc_seg = nn.Sequential(*seg_fc_layers)

            seg_quarter = ngf_seg // 4
            seg_eighth = ngf_seg // 8
            seg_sixteenth = ngf_seg // 16
            # ngf x 4 x 4 -> ngf/2 x 8 x 8
            self.upsample1_seg = upBlock(ngf_seg, seg_half)
            # -> ngf/4 x 16 x 16
            self.upsample2_seg = upBlock(seg_half, seg_quarter)
            # -> ngf/8 x 32 x 32
            self.upsample3_seg = upBlock(seg_quarter, seg_eighth)
            # -> ngf/16 x 64 x 64
            self.upsample4_seg = upBlock(seg_eighth, seg_sixteenth)
            # Single-channel output squashed by tanh.
            self.img_seg = nn.Sequential(
                conv3x3(seg_sixteenth, 1), nn.Tanh())

        # Same-width linear + batch-norm maps for the motion and content codes.
        motion_layers = [
            nn.Linear(motion_dim, motion_dim),
            nn.BatchNorm1d(motion_dim),
        ]
        self.m_net = nn.Sequential(*motion_layers)

        content_layers = [
            nn.Linear(content_dim, content_dim),
            nn.BatchNorm1d(content_dim),
        ]
        self.c_net = nn.Sequential(*content_layers)

        # 1D dynamic filter layer with half-kernel padding.
        ksize = self.filter_size
        self.dfn_layer = DynamicFilterLayer(ksize, pad=ksize // 2)