Esempio n. 1
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path):
        """Build the dataset, the networks with their optimizers, and helpers.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir

        # The remaining dataset arguments are read from the dataset section
        # and passed positionally in exactly this order.
        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space")
        self.dataset = IllustDataset(data_path, sketch_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        gen_cfg = model_config["generator"]
        generator = Generator(layers=gen_cfg["num_layers"],
                              attn_type=gen_cfg["attn_type"])
        self.gen, self.gen_opt = self._setting_model_optim(generator, gen_cfg)

        discriminator = Discriminator()
        self.dis, self.dis_opt = self._setting_model_optim(
            discriminator, model_config["discriminator"])

        # NOTE(review): the class object is stored here without being called;
        # confirm the calculator is meant to be used through class-level
        # methods rather than an instance.
        self.lossfunc = StyleAdaINLossCalculator
        self.visualizer = Visualizer(data_cfg["color_space"])
Esempio n. 2
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 ss_path):
        """Build the dataset, one generator, three discriminators and helpers.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
            ss_path: Third positional argument of ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir

        # Config-driven dataset arguments, passed positionally in this order.
        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space")
        self.dataset = IllustDataset(data_path, sketch_path, ss_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        gen_cfg = model_config["generator"]
        generator = Generator(
            gen_cfg["in_ch"],
            **{name: gen_cfg[name]
               for name in ("num_layers", "attn_type", "guide")})
        self.gen, self.gen_opt = self._setting_model_optim(generator, gen_cfg)
        self.guide = gen_cfg["guide"]

        def build_discriminator(section):
            # Each discriminator is configured from its own model section and
            # handed to the shared model/optimizer setup.
            dis_cfg = model_config[section]
            network = Discriminator(dis_cfg["in_ch"], dis_cfg["multi"])
            return self._setting_model_optim(network, dis_cfg)

        self.i_dis, self.i_dis_opt = build_discriminator("image_dis")
        self.s_dis, self.s_dis_opt = build_discriminator("surface_dis")
        self.t_dis, self.t_dis_opt = build_discriminator("texture_dis")

        # Two guided filters with different radius/epsilon settings, both
        # moved to CUDA.
        self.guided_filter = GuidedFilter(r=5, eps=2e-1)
        self.guided_filter.cuda()

        self.out_guided_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_guided_filter.cuda()

        # Vgg19 is created with requires_grad=False, moved to CUDA and
        # switched to eval mode.
        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = WhiteBoxLossCalculator()
        self.visualizer = Visualizer(data_cfg["color_space"])
Esempio n. 3
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 ss_path):
        """Build the dataset, networks, optimizers, helpers and LR schedulers.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
            ss_path: Third positional argument of ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir

        # Config-driven dataset arguments, passed positionally in this order.
        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space")
        self.dataset = IllustDataset(data_path, sketch_path, ss_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        # Generator: in_ch is positional, everything else is forwarded as a
        # keyword argument named after its config key.
        gen_cfg = model_config["generator"]
        gen_kwargs = {
            name: gen_cfg[name]
            for name in ("base", "num_layers", "up_layers", "guide",
                         "resnext", "encoder_type")
        }
        generator = Generator(gen_cfg["in_ch"], **gen_kwargs)
        self.gen, self.gen_opt = self._setting_model_optim(generator, gen_cfg)
        self.guide = gen_cfg["guide"]

        # Discriminator: in_ch and multi are positional, the rest keywords.
        dis_cfg = model_config["discriminator"]
        dis_kwargs = {
            name: dis_cfg[name]
            for name in ("base", "sn", "resnext", "patch")
        }
        discriminator = Discriminator(dis_cfg["in_ch"], dis_cfg["multi"],
                                      **dis_kwargs)
        self.dis, self.dis_opt = self._setting_model_optim(discriminator,
                                                           dis_cfg)

        # Vgg19 is created with requires_grad=False, moved to CUDA and
        # switched to eval mode.
        self.vgg = Vgg19(requires_grad=False, layer="four")
        self.vgg.cuda()
        self.vgg.eval()

        self.out_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_filter.cuda()

        self.lossfunc = LossCalculator()
        self.visualizer = Visualizer(data_cfg["color_space"])

        # Both optimizers get an exponential decay driven by train["gamma"].
        exponential_lr = torch.optim.lr_scheduler.ExponentialLR
        self.scheduler_gen = exponential_lr(self.gen_opt,
                                            self.train_config["gamma"])
        self.scheduler_dis = exponential_lr(self.dis_opt,
                                            self.train_config["gamma"])
Esempio n. 4
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path):
        """Build the dataset, local/global generators, discriminator, helpers.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections; ``config["train"]["mask"]`` selects the
                generator input channel count.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir
        self.mask = self.train_config["mask"]

        # Config-driven dataset arguments, passed positionally in this order.
        dataset_keys = ("extension", "train_size", "valid_size",
                        "color_space", "line_space")
        self.dataset = IllustDataset(data_path, sketch_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        # A truthy mask setting doubles the generator input channels.
        in_ch = 6 if self.mask else 3

        local_cfg = model_config["local_enhancer"]
        local_net = LocalEnhancer(in_ch=in_ch,
                                  num_layers=local_cfg["num_layers"])
        self.loc_gen, self.loc_gen_opt = self._setting_model_optim(local_net,
                                                                   local_cfg)

        global_net = GlobalGenerator(in_ch=in_ch)
        self.glo_gen, self.glo_gen_opt = self._setting_model_optim(
            global_net, model_config["global_generator"])

        dis_cfg = model_config["discriminator"]
        discriminator = Discriminator(dis_cfg["in_ch"], dis_cfg["multi"])
        self.dis, self.dis_opt = self._setting_model_optim(discriminator,
                                                           dis_cfg)

        # Vgg19 is created with requires_grad=False, moved to CUDA and
        # switched to eval mode.
        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = Pix2pixHDCalculator()
        self.visualizer = Visualizer(data_cfg["color_space"])
Esempio n. 5
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 dist_path, pretrained_path):
        """Build the dataset, a pretrained CTN, trainable networks and helpers.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
            dist_path: Third positional argument of ``IllustDataset``.
            pretrained_path: Passed to ``torch.load``; the result is loaded
                into the ``ColorTransformNetwork`` as its state dict.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir

        # Config-driven dataset arguments, passed positionally in this order.
        dataset_keys = ("anime_dir", "extension", "train_size",
                        "valid_size", "scale", "frame_range")
        self.dataset = IllustDataset(data_path, sketch_path, dist_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        # The CTN gets no optimizer here: it is moved to CUDA, put in eval
        # mode and then filled with the weights stored at pretrained_path.
        self.ctn = ColorTransformNetwork(
            layers=model_config["CTN"]["num_layers"])
        self.ctn.cuda()
        self.ctn.eval()
        self.ctn.load_state_dict(torch.load(pretrained_path))

        temporal_net = TemporalConstraintNetwork()
        self.gen, self.gen_opt = self._setting_model_optim(
            temporal_net, model_config["TCN"])

        discriminator = Discriminator()
        self.dis, self.dis_opt = self._setting_model_optim(
            discriminator, model_config["discriminator"])

        temporal_dis = TemporalDiscriminator()
        self.t_dis, self.t_dis_opt = self._setting_model_optim(
            temporal_dis, model_config["temporal_discriminator"])

        # Vgg19 is created with requires_grad=False, moved to CUDA and
        # switched to eval mode.
        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = VideoColorizeLossCalculator()
        self.visualizer = Visualizer()
Esempio n. 6
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path):
        """Build the dataset, generator, discriminator, loss and visualizer.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections; ``config["train"]["mask"]`` selects the
                generator input channel count.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir
        self.mask = self.train_config["mask"]

        # Config-driven dataset arguments, passed positionally in this order.
        dataset_keys = ("extension", "train_size", "valid_size",
                        "color_space", "line_space")
        self.dataset = IllustDataset(data_path, sketch_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        # A truthy mask setting doubles the generator input channels.
        in_ch = 6 if self.mask else 3

        generator = Generator(in_ch=in_ch)
        self.gen, self.gen_opt = self._setting_model_optim(
            generator, model_config["generator"])

        discriminator = Discriminator()
        self.dis, self.dis_opt = self._setting_model_optim(
            discriminator, model_config["discriminator"])

        self.lossfunc = Pix2pixCalculator()
        self.visualizer = Visualizer(data_cfg["color_space"])
Esempio n. 7
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path):
        """Build the dataset, generator, discriminator, Vgg19 and helpers.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of ``IllustDataset``.
            sketch_path: Second positional argument of ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = data_cfg = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir, self.modeldir = outdir, modeldir

        # Config-driven dataset arguments, passed positionally in this order.
        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space",
                        "src_perturbation", "tgt_perturbation")
        self.dataset = IllustDataset(data_path, sketch_path,
                                     *(data_cfg[key] for key in dataset_keys))
        print(self.dataset)

        generator = Generator()
        self.gen, self.gen_opt = self._setting_model_optim(
            generator, model_config["generator"])

        discriminator = Discriminator()
        self.dis, self.dis_opt = self._setting_model_optim(
            discriminator, model_config["discriminator"])

        # Vgg19 is created with requires_grad=False, moved to CUDA and
        # switched to eval mode.
        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = SCFTLossCalculator()
        self.visualizer = Visualizer(data_cfg["color_space"])
Esempio n. 8
0
    def __init__(self,
                 config,
                 outdir,
                 modeldir,
                 data_path,
                 sketch_path,
                 flat_path,
                 pretrain_path=None):
        """Build the dataset, networks, optimizers and helpers.

        ``config["train"]["train_type"]`` selects between two setups. With
        ``"multi"`` a ``DanbooruFacesDataset`` is used, the flat generator is
        initialised from ``pretrain_path`` and the BicycleGAN, latent encoder,
        bicycle discriminator and color fixer are created as well. With any
        other value an ``IllustDataset`` (which also receives ``flat_path``)
        is used and only the flat generator and flat discriminator are built.

        Args:
            config: Mapping providing the "train", "dataset", "model" and
                "loss" sections.
            outdir: Kept on the instance as ``self.outdir``.
            modeldir: Kept on the instance as ``self.modeldir``.
            data_path: First positional argument of the dataset.
            sketch_path: Second positional argument of the dataset.
            flat_path: Third positional argument of ``IllustDataset``; not
                used when ``train_type`` is "multi".
            pretrain_path: Passed to ``torch.load`` to obtain the flat
                generator state dict; required when ``train_type`` is
                "multi", ignored otherwise.

        Raises:
            ValueError: If ``train_type`` is "multi" and ``pretrain_path`` is
                ``None``.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        self.train_type = self.train_config["train_type"]
        is_multi = self.train_type == "multi"

        # Fail fast with a clear message instead of letting torch.load(None)
        # blow up later, after the dataset and generator have been built.
        if is_multi and pretrain_path is None:
            raise ValueError(
                'pretrain_path must be provided when train_type is "multi"')

        # Both dataset classes take the same config-driven tail arguments.
        data_cfg = self.data_config
        dataset_tail = tuple(
            data_cfg[key]
            for key in ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space"))
        if is_multi:
            self.dataset = DanbooruFacesDataset(data_path, sketch_path,
                                                *dataset_tail)
        else:
            self.dataset = IllustDataset(data_path, sketch_path, flat_path,
                                         *dataset_tail)
        print(self.dataset)

        flat_cfg = model_config["flat_generator"]
        flat_gen = Generator(flat_cfg["in_ch"],
                             num_layers=flat_cfg["num_layers"],
                             attn_type=flat_cfg["attn_type"])
        self.flat_gen, self.flat_gen_opt = self._setting_model_optim(
            flat_gen, flat_cfg)

        if is_multi:
            # Start the flat generator from the pretrained weights.
            weight = torch.load(pretrain_path)
            self.flat_gen.load_state_dict(weight)

        flat_dis_cfg = model_config["flat_dis"]
        f_dis = Discriminator(flat_dis_cfg["in_ch"], flat_dis_cfg["multi"])
        self.f_dis, self.f_dis_opt = self._setting_model_optim(f_dis,
                                                               flat_dis_cfg)

        if is_multi:
            bicycle_cfg = model_config["bicycle_gan"]
            bicycle_gen = BicycleGAN(bicycle_cfg["in_ch"],
                                     latent_dim=bicycle_cfg["l_dim"],
                                     num_layers=bicycle_cfg["num_layers"])
            self.b_gen, self.b_gen_opt = self._setting_model_optim(
                bicycle_gen, bicycle_cfg)

            encoder_cfg = model_config["encoder"]
            latent_enc = LatentEncoder(encoder_cfg["in_ch"],
                                       latent_dim=encoder_cfg["l_dim"])
            self.l_enc, self.l_enc_opt = self._setting_model_optim(
                latent_enc, encoder_cfg)

            bicycle_dis_cfg = model_config["bicycle_dis"]
            b_dis = Discriminator(bicycle_dis_cfg["in_ch"],
                                  bicycle_dis_cfg["multi"])
            self.b_dis, self.b_dis_opt = self._setting_model_optim(
                b_dis, bicycle_dis_cfg)

            fixer = ColorFixer()
            self.fix, self.fix_opt = self._setting_model_optim(
                fixer, model_config["fixer"])

        # Vgg19 is created with requires_grad=False, moved to CUDA and
        # switched to eval mode.
        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.out_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_filter.cuda()

        self.lossfunc = DecomposeLossCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])
Esempio n. 9
0
def train(epochs, interval, batchsize, validsize, data_path, sketch_path,
          extension, img_size, outdir, modeldir, gen_learning_rate,
          dis_learning_rate, beta1, beta2):
    """Adversarially train ``Style2Paint`` against ``Discriminator``.

    Every iteration runs one discriminator update followed by one generator
    update. Periodically the generator weights are saved and the validation
    batch is rendered through ``Visualizer``.

    Args:
        epochs: Number of passes over the dataset.
        interval: Snapshot period in iterations; a snapshot is taken on
            iterations 1, 1 + interval, 1 + 2 * interval, ...
        batchsize: Batch size given to the ``DataLoader``.
        validsize: Passed to ``dataset.valid`` and to the visualizer.
        data_path: First argument of ``IllustDataset``.
        sketch_path: Second argument of ``IllustDataset``.
        extension: Third argument of ``IllustDataset``.
        img_size: Argument of ``LineCollator``.
        outdir: Output location handed to the visualizer.
        modeldir: Directory receiving ``model_<iteration>.pt`` files.
        gen_learning_rate: Adam learning rate for the generator.
        dis_learning_rate: Adam learning rate for the discriminator.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
    """
    # Dataset Definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & Optimizer Definition
    model = Style2Paint()
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(),
                               lr=gen_learning_rate,
                               betas=(beta1, beta2))

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=dis_learning_rate,
                               betas=(beta1, beta2))

    vgg = Vgg19(requires_grad=False)
    vgg.cuda()
    vgg.eval()

    # Loss function definition
    lossfunc = Style2paintsLossCalculator()

    # Visualizer definition
    visualizer = Visualizer()

    # The loader is loop-invariant: with shuffle=True a fresh order is drawn
    # each time it is iterated, so one instance serves every epoch.
    dataloader = DataLoader(dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            collate_fn=collator,
                            drop_last=True)

    # `iteration % interval == 1` can never hold when interval == 1, which
    # would silently disable snapshots; comparing against 1 % interval keeps
    # the schedule for interval > 1 and fires every iteration for interval 1.
    snapshot_phase = 1 % interval

    iteration = 0

    for _ in range(epochs):
        for jit, war, line in tqdm(dataloader):
            iteration += 1

            # Discriminator update. The generator output is only used as a
            # constant here, so skip building its autograd graph.
            with torch.no_grad():
                fake = model(line, war)
            dis_loss = lossfunc.adversarial_disloss(discriminator, fake, jit)

            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            # Generator update
            y = model(line, war)
            gen_loss = lossfunc.adversarial_genloss(discriminator, y)
            gen_loss = gen_loss + 10.0 * lossfunc.content_loss(y, jit)
            gen_loss = gen_loss + lossfunc.style_and_perceptual_loss(
                vgg, y, jit)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            if iteration % interval == snapshot_phase:
                torch.save(model.state_dict(),
                           f"{modeldir}/model_{iteration}.pt")

                # NOTE(review): the model is still in train mode for this
                # preview pass — confirm that is intended.
                with torch.no_grad():
                    preview = model(l_valid, c_valid)

                color_np = c_valid.detach().cpu().numpy()
                line_np = l_valid.detach().cpu().numpy()
                preview_np = preview.detach().cpu().numpy()

                visualizer(line_np, color_np, preview_np, outdir, iteration,
                           validsize)

            # Report the generator loss as a plain Python number.
            print(f"iteration: {iteration} Loss: {gen_loss.item()}")
Esempio n. 10
0
def train(epochs,
          interval,
          batchsize,
          validsize,
          data_path,
          sketch_path,
          extension,
          img_size,
          outdir,
          modeldir,
          learning_rate):
    """Adversarially train ``Style2Paint(attn_type="adain")``.

    Every iteration runs one discriminator update followed by one generator
    update. Periodically the generator weights are saved and the validation
    batch is rendered through ``Visualizer``.

    Args:
        epochs: Number of passes over the dataset.
        interval: Snapshot period in iterations; a snapshot is taken on
            iterations 1, 1 + interval, 1 + 2 * interval, ...
        batchsize: Batch size given to the ``DataLoader``.
        validsize: Passed to ``dataset.valid`` and to the visualizer.
        data_path: First argument of ``IllustDataset``.
        sketch_path: Second argument of ``IllustDataset``.
        extension: Third argument of ``IllustDataset``.
        img_size: Argument of ``LineCollator``.
        outdir: Output location handed to the visualizer.
        modeldir: Directory receiving ``model_<iteration>.pt`` files.
        learning_rate: Adam learning rate used for both optimizers.
    """
    # Dataset Definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & Optimizer Definition
    model = Style2Paint(attn_type="adain")
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

    # Loss function definition
    lossfunc = Style2paintsLossCalculator()

    # Visualizer definition
    visualizer = Visualizer()

    # The loader is loop-invariant: with shuffle=True a fresh order is drawn
    # each time it is iterated, so one instance serves every epoch.
    dataloader = DataLoader(dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            collate_fn=collator,
                            drop_last=True)

    # `iteration % interval == 1` can never hold when interval == 1, which
    # would silently disable snapshots; comparing against 1 % interval keeps
    # the schedule for interval > 1 and fires every iteration for interval 1.
    snapshot_phase = 1 % interval

    iteration = 0

    for _ in range(epochs):
        for color, line in tqdm(dataloader):
            iteration += 1

            # Discriminator update. The generator output is only used as a
            # constant here, so skip building its autograd graph.
            with torch.no_grad():
                fake = model(line, color)
            dis_loss = 0.01 * lossfunc.adversarial_disloss(
                discriminator, fake, color)

            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            # Generator update
            y = model(line, color)
            gen_loss = 0.01 * lossfunc.adversarial_genloss(discriminator, y)
            gen_loss = gen_loss + maeloss(y, color)
            gen_loss = gen_loss + 0.001 * lossfunc.positive_enforcing_loss(y)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            if iteration % interval == snapshot_phase:
                torch.save(model.state_dict(),
                           f"{modeldir}/model_{iteration}.pt")

                # NOTE(review): the model is still in train mode for this
                # preview pass — confirm that is intended.
                with torch.no_grad():
                    preview = model(l_valid, c_valid)

                color_np = c_valid.detach().cpu().numpy()
                line_np = l_valid.detach().cpu().numpy()
                preview_np = preview.detach().cpu().numpy()

                visualizer(line_np, color_np, preview_np, outdir, iteration,
                           validsize)

            # Report the generator loss as a plain Python number.
            print(f"iteration: {iteration} Loss: {gen_loss.item()}")