Exemple #1
0
 def __init__(self, args):
     """Build the three searchable cells and the RGB output head.

     Args:
         args: configuration object; ``gf_dim`` (channel width used by every
             cell and by the head) and ``bottom_width`` are read from it.

     No ``l1`` latent-to-feature linear layer is created here, so ``forward``
     presumably receives feature maps rather than latent vectors -- TODO
     confirm against this class's forward().
     """
     super(GeneratorED, self).__init__()
     width = args.gf_dim
     self.args = args
     self.ch = width
     self.bottom_width = args.bottom_width
     # num_skip_in grows by one per cell (presumably one skip input from each
     # earlier cell -- confirm against Cell).
     self.cell1 = Cell(width, width, num_skip_in=0)
     self.cell2 = Cell(width, width, num_skip_in=1)
     self.cell3 = Cell(width, width, num_skip_in=2)
     # Output head: normalize, activate, project to 3 channels, squash to [-1, 1].
     head = [
         nn.BatchNorm2d(width),
         nn.ReLU(inplace=False),
         nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1),
         nn.Tanh(),
     ]
     self.to_rgb = nn.Sequential(*head)
Exemple #2
0
class Generator(nn.Module):
    """Progressively grown generator built from four searchable ``Cell`` stages.

    ``set_arch`` configures the cells and selects ``cur_stage``. ``forward``
    runs the cells up to and including that stage. It then renders the last
    feature map to an image with ``to_rgb``.

    Args:
        args: configuration object; ``gf_dim``, ``bottom_width`` and
            ``latent_dim`` are read from it.
    """

    # Number of renderable stages (one per cell); valid stages are 0..3.
    _NUM_STAGES = 4

    def __init__(self, args):
        super(Generator, self).__init__()
        self.args = args
        self.ch = args.gf_dim
        self.bottom_width = args.bottom_width
        # Latent vector -> flat features, reshaped in forward() to
        # (-1, gf_dim, bottom_width, bottom_width).
        self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * args.gf_dim)
        # cellN is fed the skip outputs of the N-1 cells before it (see forward).
        self.cell1 = Cell(args.gf_dim, args.gf_dim, num_skip_in=0)
        self.cell2 = Cell(args.gf_dim, args.gf_dim, num_skip_in=1)
        self.cell3 = Cell(args.gf_dim, args.gf_dim, num_skip_in=2)
        self.cell4 = Cell(args.gf_dim, args.gf_dim, num_skip_in=3)
        self.to_rgb = nn.Sequential(
            nn.BatchNorm2d(args.gf_dim),
            nn.ReLU(inplace=False),
            nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
            nn.Tanh()
        )

    def _check_stage(self, cur_stage):
        """Raise ValueError unless ``cur_stage`` is a stage this model can render."""
        if not 0 <= cur_stage < self._NUM_STAGES:
            raise ValueError("cur_stage must be in [0, %d], got %r"
                             % (self._NUM_STAGES - 1, cur_stage))

    def set_arch(self, arch_id, cur_stage):
        """Configure the cells from a flat architecture encoding and select the stage.

        Layout of ``arch_id``:

        - Indices 0-3 configure ``cell1`` (conv, norm, up, shortcut).
        - Each later cell uses five entries (conv, norm, up, shortcut,
          skip_ins): 4-8 for ``cell2``, 9-13 for ``cell3``, 14-18 for
          ``cell4``.

        Only cells up to ``cur_stage`` are configured.

        Args:
            arch_id: list/tuple of numbers, or an array-like exposing
                ``tolist()``, such as a torch tensor on any device, including
                one that requires grad.
            cur_stage: index of the deepest active cell, 0-3.

        Raises:
            ValueError: if ``cur_stage`` is outside 0-3.
            IndexError: if ``arch_id`` is too short for ``cur_stage``.
        """
        self._check_stage(cur_stage)
        if not isinstance(arch_id, (list, tuple)):
            # tolist() copies device tensors to the host and, unlike .numpy(),
            # does not fail on tensors that require grad.
            arch_id = arch_id.tolist()
        arch_id = [int(x) for x in arch_id]
        self.cur_stage = cur_stage
        self.cell1.set_arch(conv_id=arch_id[0], norm_id=arch_id[1], up_id=arch_id[2],
                            short_cut_id=arch_id[3], skip_ins=[])
        later_cells = (self.cell2, self.cell3, self.cell4)
        for stage, cell in enumerate(later_cells, start=1):
            if stage > cur_stage:
                break
            # Stage k's five entries start at 4 + 5 * (k - 1).
            seg = arch_id[4 + 5 * (stage - 1):]
            cell.set_arch(conv_id=seg[0], norm_id=seg[1], up_id=seg[2],
                          short_cut_id=seg[3], skip_ins=seg[4])

    def forward(self, z):
        """Generate images from latent vectors using the cells up to ``cur_stage``.

        Args:
            z: latent batch accepted by ``self.l1``. This is presumably
                (batch, latent_dim); confirm against callers.

        Returns:
            ``to_rgb`` applied to the output of cell ``cur_stage + 1``.

        Raises:
            ValueError: if ``cur_stage`` is not in 0-3.
        """
        self._check_stage(self.cur_stage)
        h = self.l1(z).view(-1, self.ch, self.bottom_width, self.bottom_width)
        h1_skip_out, h1 = self.cell1(h)
        if self.cur_stage == 0:
            return self.to_rgb(h1)
        h2_skip_out, h2 = self.cell2(h1, (h1_skip_out,))
        if self.cur_stage == 1:
            return self.to_rgb(h2)
        h3_skip_out, h3 = self.cell3(h2, (h1_skip_out, h2_skip_out))
        if self.cur_stage == 2:
            return self.to_rgb(h3)
        # Only stage 3 remains after validation.
        _, h4 = self.cell4(h3, (h1_skip_out, h2_skip_out, h3_skip_out))
        return self.to_rgb(h4)
Exemple #3
0
class Generator(nn.Module):
    """Three-cell progressive generator with optional detached feature summaries.

    Stages:

    - ``-1`` renders the projected latent directly.
    - ``0``-``2`` render the output of ``cell1``-``cell3``.

    With ``eval=True`` (and always at stage -1), ``forward`` returns a tuple.
    Its second element is a detached 1x1 bilinear down-sample of the rendered
    feature map, described in the original comments as "the state".

    Args:
        args: configuration object; ``gf_dim``, ``bottom_width`` and
            ``latent_dim`` are read from it.
    """

    # Stages forward() can render; -1 means "no cell, render l1's output".
    _VALID_STAGES = (-1, 0, 1, 2)

    def __init__(self, args):
        super(Generator, self).__init__()
        self.args = args
        self.ch = args.gf_dim
        self.bottom_width = args.bottom_width
        # Latent vector -> flat features, reshaped in forward() to
        # (-1, gf_dim, bottom_width, bottom_width).
        self.l1 = nn.Linear(args.latent_dim,
                            (self.bottom_width**2) * args.gf_dim)
        # cellN is fed the skip outputs of the N-1 cells before it (see forward).
        self.cell1 = Cell(args.gf_dim, args.gf_dim, num_skip_in=0)
        self.cell2 = Cell(args.gf_dim, args.gf_dim, num_skip_in=1)
        self.cell3 = Cell(args.gf_dim, args.gf_dim, num_skip_in=2)
        self.to_rgb = nn.Sequential(nn.BatchNorm2d(args.gf_dim), nn.ReLU(),
                                    nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
                                    nn.Tanh())

    def _check_stage(self, cur_stage):
        """Raise ValueError unless ``cur_stage`` is one of ``_VALID_STAGES``."""
        if cur_stage not in self._VALID_STAGES:
            raise ValueError("cur_stage must be one of %s, got %r"
                             % (self._VALID_STAGES, cur_stage))

    def set_stage(self, cur_stage):
        """Select the active stage without touching the cell architectures.

        Raises:
            ValueError: if ``cur_stage`` is not in ``_VALID_STAGES``.
        """
        self._check_stage(cur_stage)
        self.cur_stage = cur_stage

    def set_arch(self, arch_id, cur_stage):
        """Configure the cells from a flat architecture encoding and select the stage.

        Layout of ``arch_id``:

        - Indices 0-3 configure ``cell1`` (conv, norm, up, shortcut).
        - 4-8 configure ``cell2`` (conv, norm, up, shortcut, skip_ins).
        - 9-13 configure ``cell3`` in the same way.

        ``cell1`` is always configured. The later cells are configured only
        when ``cur_stage`` reaches them.

        Args:
            arch_id: list/tuple of numbers, or an array-like exposing
                ``tolist()``, such as a torch tensor on any device, including
                one that requires grad.
            cur_stage: stage to activate; one of ``_VALID_STAGES``.

        Raises:
            ValueError: if ``cur_stage`` is not in ``_VALID_STAGES``.
            IndexError: if ``arch_id`` is too short for ``cur_stage``.
        """
        self._check_stage(cur_stage)
        if not isinstance(arch_id, (list, tuple)):
            # tolist() copies device tensors to the host and, unlike .numpy(),
            # does not fail on tensors that require grad.
            arch_id = arch_id.tolist()
        arch_id = [int(x) for x in arch_id]
        self.cur_stage = cur_stage
        arch_stage1 = arch_id[:4]
        self.cell1.set_arch(conv_id=arch_stage1[0],
                            norm_id=arch_stage1[1],
                            up_id=arch_stage1[2],
                            short_cut_id=arch_stage1[3],
                            skip_ins=[])
        if cur_stage >= 1:
            arch_stage2 = arch_id[4:9]
            self.cell2.set_arch(conv_id=arch_stage2[0],
                                norm_id=arch_stage2[1],
                                up_id=arch_stage2[2],
                                short_cut_id=arch_stage2[3],
                                skip_ins=arch_stage2[4])

        if cur_stage >= 2:
            arch_stage3 = arch_id[9:]
            self.cell3.set_arch(conv_id=arch_stage3[0],
                                norm_id=arch_stage3[1],
                                up_id=arch_stage3[2],
                                short_cut_id=arch_stage3[3],
                                skip_ins=arch_stage3[4])

    @staticmethod
    def _summarize(h):
        """Return a detached (N, C, 1, 1) bilinear down-sample of feature map ``h``.

        The original comment explains the 1x1 size: the full down-sampled
        shape would be [z_num, C, h, w], and 1x1 keeps the state small.
        """
        return F.interpolate(h, size=(1, 1), mode="bilinear").detach()

    def _render_blended(self, h, h_prev, smooth):
        """Render ``h``, cross-faded with the upsampled previous-stage map ``h_prev``.

        Returns ``smooth * to_rgb(h) + (1 - smooth) * to_rgb(up(h_prev))``.
        """
        if smooth == 1.:
            # The second term would be multiplied by 0. Computing it anyway
            # wastes a to_rgb pass, and in training mode it also updates
            # to_rgb's BatchNorm running stats with previous-stage features.
            return self.to_rgb(h)
        _, _, ht, wt = h.size()
        h_prev_up = F.interpolate(h_prev, size=(ht, wt), mode="bilinear")
        return smooth * self.to_rgb(h) + (1. - smooth) * self.to_rgb(h_prev_up)

    def forward(self, z, smooth=1., eval=False):
        """Generate images using the cells up to ``cur_stage``.

        Args:
            z: latent batch accepted by ``self.l1``. This is presumably
                (batch, latent_dim); confirm against callers.
            smooth: blend weight between the current stage's rendering and
                the upsampled previous stage's rendering (stages 1 and 2,
                ``eval=False`` only). The original notes say it was left at
                1. in the final submission.
            eval: when True, also return the detached 1x1 feature summary.
                The name shadows the builtin but is kept for callers that
                pass it by keyword.

        Returns:
            An image tensor, or ``(image, summary)`` when ``eval`` is True.
            At stage -1 the tuple is returned regardless of ``eval``.

        Raises:
            ValueError: if ``cur_stage`` is not in ``_VALID_STAGES``.
        """
        self._check_stage(self.cur_stage)
        h = self.l1(z).view(-1, self.ch, self.bottom_width, self.bottom_width)
        if self.cur_stage == -1:
            return self.to_rgb(h), self._summarize(h)

        h1_skip_out, h1 = self.cell1(h)
        if self.cur_stage == 0:
            if eval:
                return self.to_rgb(h1), self._summarize(h1)
            return self.to_rgb(h1)

        h2_skip_out, h2 = self.cell2(h1, (h1_skip_out, ))
        if self.cur_stage == 1:
            if eval:
                return self.to_rgb(h2), self._summarize(h2)
            return self._render_blended(h2, h1, smooth)

        # Only stage 2 remains after validation.
        _, h3 = self.cell3(h2, (h1_skip_out, h2_skip_out))
        if eval:
            return self.to_rgb(h3), self._summarize(h3)
        return self._render_blended(h3, h2, smooth)