Example no. 1
0
def train(epochs, interval, batchsize, validsize, data_path, sketch_path,
          extension, img_size, outdir, modeldir, gen_learning_rate,
          dis_learning_rate, beta1, beta2):
    """Adversarially train a Style2Paint colorization model.

    Alternates discriminator and generator updates; the generator loss
    combines an adversarial term, a content term (weight 10.0), and a
    VGG19-based style/perceptual term. Every ``interval`` iterations a
    checkpoint is saved to ``modeldir`` and a validation visualization is
    written to ``outdir``.

    Args:
        epochs: Number of passes over the training dataset.
        interval: Checkpoint/visualization period, in iterations.
        batchsize: Training batch size.
        validsize: Number of held-out validation samples.
        data_path: Directory of color illustrations.
        sketch_path: Directory of line-art sketches.
        extension: Image file extension accepted by the dataset.
        img_size: Target image size used by the collator.
        outdir: Output directory for validation snapshots.
        modeldir: Output directory for model checkpoints.
        gen_learning_rate: Adam learning rate for the generator.
        dis_learning_rate: Adam learning rate for the discriminator.
        beta1: Adam beta1 coefficient (shared by both optimizers).
        beta2: Adam beta2 coefficient (shared by both optimizers).
    """
    # Dataset Definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & Optimizer Definition
    model = Style2Paint()
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(),
                               lr=gen_learning_rate,
                               betas=(beta1, beta2))

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=dis_learning_rate,
                               betas=(beta1, beta2))

    # Frozen VGG19 used only as a fixed feature extractor for the
    # style/perceptual loss — never updated, so keep it in eval mode.
    vgg = Vgg19(requires_grad=False)
    vgg.cuda()
    vgg.eval()

    # Loss function definition
    lossfunc = Style2paintsLossCalculator()

    # Visualizer definition
    visualizer = Visualizer()

    iteration = 0

    # The loader configuration never changes, so build it once; iterating
    # a shuffle=True DataLoader reshuffles at the start of every epoch.
    dataloader = DataLoader(dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            collate_fn=collator,
                            drop_last=True)

    for epoch in range(epochs):
        progress_bar = tqdm(dataloader)

        for index, data in enumerate(progress_bar):
            iteration += 1
            # presumably (color-jittered target, warped reference, line art)
            # — TODO confirm against LineCollator
            jit, war, line = data

            # Discriminator update: detach y so generator gradients do not
            # flow through the discriminator loss.
            y = model(line, war)
            loss = lossfunc.adversarial_disloss(discriminator, y.detach(), jit)

            dis_opt.zero_grad()
            loss.backward()
            dis_opt.step()

            # Generator update (fresh forward pass after the discriminator
            # has been stepped).
            y = model(line, war)
            loss = lossfunc.adversarial_genloss(discriminator, y)
            loss += 10.0 * lossfunc.content_loss(y, jit)
            loss += lossfunc.style_and_perceptual_loss(vgg, y, jit)

            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

            if iteration % interval == 1:
                torch.save(model.state_dict(),
                           f"{modeldir}/model_{iteration}.pt")

                # Run validation in eval mode so dropout/batch-norm layers
                # (if any) behave deterministically, then restore train mode.
                model.eval()
                with torch.no_grad():
                    y = model(l_valid, c_valid)
                model.train()

                c = c_valid.detach().cpu().numpy()
                l = l_valid.detach().cpu().numpy()
                y = y.detach().cpu().numpy()

                visualizer(l, c, y, outdir, iteration, validsize)

            # .item() is the supported scalar accessor; .data is deprecated.
            print(f"iteration: {iteration} Loss: {loss.item()}")
Example no. 2
0
def train(epochs,
          interval,
          batchsize,
          validsize,
          data_path,
          sketch_path,
          extension,
          img_size,
          outdir,
          modeldir,
          learning_rate):
    """Adversarially train a Style2Paint model with AdaIN-style attention.

    Alternates discriminator and generator updates; the generator loss
    combines a down-weighted (0.01) adversarial term, an L1 reconstruction
    term, and a positive-enforcing term (weight 0.001). Every ``interval``
    iterations a checkpoint is saved to ``modeldir`` and a validation
    visualization is written to ``outdir``.

    Args:
        epochs: Number of passes over the training dataset.
        interval: Checkpoint/visualization period, in iterations.
        batchsize: Training batch size.
        validsize: Number of held-out validation samples.
        data_path: Directory of color illustrations.
        sketch_path: Directory of line-art sketches.
        extension: Image file extension accepted by the dataset.
        img_size: Target image size used by the collator.
        outdir: Output directory for validation snapshots.
        modeldir: Output directory for model checkpoints.
        learning_rate: Adam learning rate shared by both optimizers.
    """
    # Dataset Definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & Optimizer Definition
    model = Style2Paint(attn_type="adain")
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

    # Loss function definition
    lossfunc = Style2paintsLossCalculator()

    # Visualizer definition
    visualizer = Visualizer()

    iteration = 0

    # The loader configuration never changes, so build it once; iterating
    # a shuffle=True DataLoader reshuffles at the start of every epoch.
    dataloader = DataLoader(dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            collate_fn=collator,
                            drop_last=True)

    for epoch in range(epochs):
        progress_bar = tqdm(dataloader)

        for index, data in enumerate(progress_bar):
            iteration += 1
            color, line = data

            # Discriminator update: detach y so generator gradients do not
            # flow through the discriminator loss.
            y = model(line, color)
            loss = 0.01 * lossfunc.adversarial_disloss(discriminator, y.detach(), color)

            dis_opt.zero_grad()
            loss.backward()
            dis_opt.step()

            # Generator update (fresh forward pass after the discriminator
            # has been stepped).
            y = model(line, color)
            loss = 0.01 * lossfunc.adversarial_genloss(discriminator, y)
            loss += maeloss(y, color)
            loss += 0.001 * lossfunc.positive_enforcing_loss(y)

            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

            if iteration % interval == 1:
                torch.save(model.state_dict(), f"{modeldir}/model_{iteration}.pt")

                # Run validation in eval mode so dropout/batch-norm layers
                # (if any) behave deterministically, then restore train mode.
                model.eval()
                with torch.no_grad():
                    y = model(l_valid, c_valid)
                model.train()

                c = c_valid.detach().cpu().numpy()
                l = l_valid.detach().cpu().numpy()
                y = y.detach().cpu().numpy()

                visualizer(l, c, y, outdir, iteration, validsize)

            # .item() is the supported scalar accessor; .data is deprecated.
            print(f"iteration: {iteration} Loss: {loss.item()}")