def infer(testsize,
          outdir,
          model_path,
          con_path,
          sty_path,
          extension,
          coord_size,
          crop_size,
          alpha,
          ):
    """Stylize a validation batch with a trained decoder and visualize it.

    Content and style batches are encoded with ``VGG``, the last feature
    maps are combined by ``adain``, blended with the content feature using
    ``alpha`` and decoded. The content, style and generated arrays are then
    handed to ``Visualizer``.

    Args:
        testsize: Number of validation samples requested from the dataset
            loader; also forwarded to the visualizer.
        outdir: Output location forwarded to the visualizer.
        model_path: Path of the ``.npz`` snapshot loaded into the decoder.
        con_path: Content data location passed to ``DatasetLoader``.
        sty_path: Style data location passed to ``DatasetLoader``.
        extension: Passed through to ``DatasetLoader``.
        coord_size: Passed through to ``DatasetLoader``.
        crop_size: Passed through to ``DatasetLoader``.
        alpha: Blend weight; the decoder input is
            ``alpha * adain_output + (1 - alpha) * content_feature``.
    """

    # Dataset definition
    dataloader = DatasetLoader(con_path, sty_path, extension, coord_size, crop_size)
    print(dataloader)
    con_valid, sty_valid = dataloader.valid(testsize)

    # Model definition
    decoder = Decoder()
    decoder.to_gpu()
    serializers.load_npz(model_path, decoder)

    vgg = VGG()
    vgg.to_gpu()

    # Visualizer definition
    visualizer = Visualizer()

    # Nothing is back-propagated here, so disable graph construction as well
    # as train mode; otherwise every intermediate activation stays alive on
    # the GPU until the outputs are released.
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        # Only the last feature map of each encoder pass is used.
        style_feat = vgg(sty_valid)[-1]
        content_feat = vgg(con_valid)[-1]

        t = adain(content_feat, style_feat)
        t = alpha * t + (1 - alpha) * content_feat
        g_t = decoder(t)

    # Copy device arrays back to the host for visualization.
    g_t = g_t.data.get()
    con = con_valid.data.get()
    sty = sty_valid.data.get()

    visualizer(con, sty, g_t, outdir, 0, testsize)
# Example #2
# 0
def train(epochs, iterations, batchsize, validsize, src_path, tgt_path,
          extension, img_size, outdir, modeldir, lr_dis, lr_gen, beta1, beta2):
    """Train two generators and two discriminators with ``InstaGANLossFunction``.

    Each step first updates both discriminators on detached generator
    outputs, then updates both generators with the adversarial,
    cycle-consistency, identity-mapping and context-preserving losses.
    On the first step of every epoch the generators are saved and the
    validation batch is translated and visualized.

    Args:
        epochs: Number of passes of the outer loop.
        iterations: Upper bound of the inner ``range(0, iterations,
            batchsize)`` loop and divisor of the printed loss sums.
        batchsize: Step of the inner loop and size requested from
            ``dataset.train``.
        validsize: Number of validation samples requested from
            ``dataset.valid``; also forwarded to the visualizer.
        src_path: Passed through to ``DatasetLoader``.
        tgt_path: Passed through to ``DatasetLoader``.
        extension: Passed through to ``DatasetLoader``.
        img_size: Passed through to ``DatasetLoader``.
        outdir: Output location forwarded to the visualizer.
        modeldir: Directory prefix for the saved generator snapshots.
        lr_dis: Passed to ``set_optimizer`` for both discriminators.
        lr_gen: Passed to ``set_optimizer`` for both generators.
        beta1: Passed to ``set_optimizer`` for every model.
        beta2: Passed to ``set_optimizer`` for every model.
    """
    # Local import: this function did not reference the top-level package
    # before, so do not rely on a module-level ``import chainer``.
    import chainer

    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataset)
    x_val, x_mask_val, y_val, y_mask_val = dataset.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, lr_gen, beta1, beta2)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, lr_gen, beta1, beta2)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, lr_dis, beta1, beta2)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, lr_dis, beta1, beta2)

    # Loss Function definition
    lossfunc = InstaGANLossFunction()

    # Visualizer definition
    visualize = Visualizer()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0

        for batch in range(0, iterations, batchsize):
            x, x_mask, y, y_mask = dataset.train(batchsize)

            # discriminator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            # Detach the fakes so the discriminator loss does not
            # back-propagate into the generators.
            xy.unchain_backward()
            xy_mask.unchain_backward()
            yx.unchain_backward()
            yx_mask.unchain_backward()

            dis_loss = lossfunc.adversarial_dis_loss(discriminator_y, xy,
                                                     xy_mask, y, y_mask)
            dis_loss += lossfunc.adversarial_dis_loss(discriminator_x, yx,
                                                      yx_mask, x, x_mask)

            discriminator_y.cleargrads()
            discriminator_x.cleargrads()
            dis_loss.backward()
            dis_y_opt.update()
            dis_x_opt.update()

            sum_dis_loss += dis_loss.data

            # generator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            # Round trips x -> y -> x and y -> x -> y.
            xyx, xyx_mask = generator_yx(xy, xy_mask)
            yxy, yxy_mask = generator_xy(yx, yx_mask)

            # Each generator applied to a sample of its own output domain.
            x_id, x_mask_id = generator_yx(x, x_mask)
            y_id, y_mask_id = generator_xy(y, y_mask)

            gen_loss = lossfunc.adversarial_gen_loss(discriminator_y, xy,
                                                     xy_mask)
            gen_loss += lossfunc.adversarial_gen_loss(discriminator_x, yx,
                                                      yx_mask)

            gen_loss += lossfunc.cycle_consistency_loss(
                xyx, xyx_mask, x, x_mask)
            gen_loss += lossfunc.cycle_consistency_loss(
                yxy, yxy_mask, y, y_mask)

            gen_loss += lossfunc.identity_mapping_loss(x_id, x_mask_id, x,
                                                       x_mask)
            gen_loss += lossfunc.identity_mapping_loss(y_id, y_mask_id, y,
                                                       y_mask)

            gen_loss += lossfunc.context_preserving_loss(
                xy, xy_mask, x, x_mask)
            gen_loss += lossfunc.context_preserving_loss(
                yx, yx_mask, y, y_mask)

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            # Snapshot and visualize once per epoch, right after its first
            # update step.
            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model",
                                     generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model",
                                     generator_yx)

                # Validation is inference only: run in test mode and without
                # building an autograd graph. Results go to dedicated names
                # so the training variables are not rebound to host arrays.
                with chainer.using_config("train", False), \
                        chainer.no_backprop_mode():
                    xy_val, xy_mask_val = generator_xy(x_val, x_mask_val)
                    yx_val, yx_mask_val = generator_yx(y_val, y_mask_val)

                visualize(x_val.data.get(),
                          x_mask_val.data.get(),
                          xy_val.data.get(),
                          xy_mask_val.data.get(),
                          outdir,
                          epoch,
                          validsize,
                          switch="mtot")

                visualize(y_val.data.get(),
                          y_mask_val.data.get(),
                          yx_val.data.get(),
                          yx_mask_val.data.get(),
                          outdir,
                          epoch,
                          validsize,
                          switch="ttom")

        # NOTE(review): the inner loop runs once per `batchsize` step of
        # `range(0, iterations, batchsize)`, yet the sums are divided by
        # `iterations` -- confirm this scaling is intended.
        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
# Example #3
# 0
def train(epochs,
          iterations,
          batchsize,
          validsize,
          outdir,
          modeldir,
          src_path,
          tgt_path,
          extension,
          img_size,
          learning_rate,
          beta1
          ):
    """Train two generators and two discriminators with ``CycleGANLossCalculator``.

    Each step updates both discriminators on detached generator outputs and
    then both generators with adversarial, cycle-consistency and
    identity-mapping losses. On the first step of every epoch the generators
    are saved and the validation batch is translated by ``generator_xy`` and
    visualized.

    Args:
        epochs: Number of passes of the outer loop.
        iterations: Upper bound of the inner ``range(0, iterations,
            batchsize)`` loop and divisor of the printed loss sums.
        batchsize: Step of the inner loop and size requested from
            ``dataloader.train``.
        validsize: Number of validation samples requested from
            ``dataloader.valid``; also forwarded to the visualizer.
        outdir: Output location forwarded to the visualizer.
        modeldir: Directory prefix for the saved generator snapshots.
        src_path: Passed through to ``DatasetLoader``.
        tgt_path: Passed through to ``DatasetLoader``.
        extension: Passed through to ``DatasetLoader``.
        img_size: Passed through to ``DatasetLoader``.
        learning_rate: Passed to ``set_optimizer`` for every model.
        beta1: Passed to ``set_optimizer`` for every model.
    """

    # Dataset definition
    dataloader = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataloader)
    src_val = dataloader.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, learning_rate, beta1)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, learning_rate, beta1)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, learning_rate, beta1)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, learning_rate, beta1)

    # LossFunction definition
    lossfunc = CycleGANLossCalculator()

    # Visualization
    visualizer = Visualization()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, y = dataloader.train(batchsize)

            # Discriminator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            # Detach the fakes so the discriminator loss does not
            # back-propagate into the generators.
            xy.unchain_backward()
            yx.unchain_backward()

            dis_loss_xy = lossfunc.dis_loss(discriminator_y, xy, y)
            dis_loss_yx = lossfunc.dis_loss(discriminator_x, yx, x)

            dis_loss = dis_loss_xy + dis_loss_yx

            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            dis_loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            # Round trips x -> y -> x and y -> x -> y.
            xyx = generator_yx(xy)
            yxy = generator_xy(yx)

            # Each generator applied to a sample of its own output domain.
            y_id = generator_xy(y)
            x_id = generator_yx(x)

            # adversarial loss
            gen_loss_xy = lossfunc.gen_loss(discriminator_y, xy)
            gen_loss_yx = lossfunc.gen_loss(discriminator_x, yx)

            # cycle-consistency loss
            cycle_y = lossfunc.cycle_consitency_loss(yxy, y)
            cycle_x = lossfunc.cycle_consitency_loss(xyx, x)

            # identity mapping loss
            identity_y = lossfunc.identity_mapping_loss(y_id, y)
            identity_x = lossfunc.identity_mapping_loss(x_id, x)

            gen_loss = gen_loss_xy + gen_loss_yx + cycle_x + cycle_y + identity_x + identity_y

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            # Snapshot and visualize once per epoch, right after its first
            # update step.
            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                # Validation is inference only: also disable graph
                # construction so its activations are not retained.
                with chainer.using_config('train', False), \
                        chainer.no_backprop_mode():
                    tgt_val = generator_xy(src_val)

                visualizer(src_val.data.get(), tgt_val.data.get(),
                           outdir, epoch, validsize)

        # NOTE(review): the inner loop runs once per `batchsize` step of
        # `range(0, iterations, batchsize)`, yet the sums are divided by
        # `iterations` -- confirm this scaling is intended.
        print(f"epoch: {epoch}")
        print(f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}")
# Example #4
# 0
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          extension, train_size, valid_size, data_path, sketch_path, digi_path,
          learning_rate, beta1, weight_decay):
    """Train a ``UNet`` generator against a discriminator.

    Each step updates the discriminator on a detached generator output and
    then the generator with the adversarial loss plus
    ``Pix2pixLossCalculator.content_loss``. On the first step of every epoch
    the generator is saved and the validation batch is generated and
    visualized.

    Args:
        epochs: Number of passes of the outer loop.
        iterations: Upper bound of the inner ``range(0, iterations,
            batchsize)`` loop and divisor of the printed loss sums.
        batchsize: Step of the inner loop and size requested from
            ``dataset.train``.
        validsize: Number of validation samples requested from
            ``dataset.valid``; also forwarded to the visualizer.
        outdir: Output location forwarded to the visualizer.
        modeldir: Directory prefix for the saved generator snapshots.
        extension: Passed through to ``DatasetLoader``.
        train_size: Passed through to ``DatasetLoader``.
        valid_size: Passed through to ``DatasetLoader``.
        data_path: Passed through to ``DatasetLoader``.
        sketch_path: Passed through to ``DatasetLoader``.
        digi_path: Passed through to ``DatasetLoader``.
        learning_rate: Passed to ``set_optimizer`` for both models.
        beta1: Passed to ``set_optimizer`` for both models.
        weight_decay: Passed to ``set_optimizer`` for both models.
    """

    # Dataset definition
    dataset = DatasetLoader(data_path, sketch_path, digi_path, extension,
                            train_size, valid_size)
    print(dataset)
    x_val, t_val = dataset.valid(validsize)

    # Model & Optimizer definition
    unet = UNet()
    unet.to_gpu()
    unet_opt = set_optimizer(unet, learning_rate, beta1, weight_decay)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss function definition
    lossfunc = Pix2pixLossCalculator()

    # Visualization definition
    visualizer = Visualizer()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, t = dataset.train(batchsize)

            # Discriminator update
            y = unet(x)
            # Detach the fake so the discriminator loss does not
            # back-propagate into the generator.
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, y, t)

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            y = unet(x)

            gen_loss = lossfunc.gen_loss(discriminator, y)
            gen_loss += lossfunc.content_loss(y, t)

            unet.cleargrads()
            gen_loss.backward()
            unet_opt.update()

            sum_gen_loss += gen_loss.data

            # Snapshot and visualize once per epoch, right after its first
            # update step.
            if batch == 0:
                serializers.save_npz(f"{modeldir}/unet_{epoch}.model", unet)

                # Validation is inference only: also disable graph
                # construction, and keep the result in its own name so the
                # training variables are not rebound to host arrays.
                with chainer.using_config("train", False), \
                        chainer.no_backprop_mode():
                    y_val = unet(x_val)

                visualizer(x_val.data.get(), t_val.data.get(),
                           y_val.data.get(), outdir, epoch, validsize)

        # NOTE(review): the inner loop runs once per `batchsize` step of
        # `range(0, iterations, batchsize)`, yet the sums are divided by
        # `iterations` -- confirm this scaling is intended.
        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}"
        )
def train(epochs, iterations, batchsize, testsize, outdir, modeldir, n_dis,
          img_path, tag_path):
    """Train a generator and a discriminator with ``RGANLossFunction``.

    Every outer step performs ``n_dis`` discriminator updates (discriminator
    loss plus gradient penalty) followed by one generator update. On the
    first step of every epoch both models are saved and the fixed validation
    noise is passed through the generator and handed to ``Evaluation``.

    Args:
        epochs: Number of passes of the outer loop.
        iterations: Upper bound of the inner ``range(0, iterations,
            batchsize)`` loop and divisor of the printed loss sum.
        batchsize: Step of the inner loop and size requested from
            ``dataloader.train``.
        testsize: Number of validation samples requested from
            ``dataloader.valid``; also forwarded to the evaluator.
        outdir: Output location forwarded to the evaluator.
        modeldir: Directory prefix for the saved model snapshots.
        n_dis: Number of discriminator updates per generator update.
        img_path: Passed through to ``DatasetLoader``.
        tag_path: Passed through to ``DatasetLoader``.
    """
    # Dataset Definition
    dataloader = DatasetLoader(img_path, tag_path)
    # The evaluator is told there are `testsize` samples, so generate exactly
    # that many (previously `batchsize`, which disagreed whenever the two
    # values differed).
    zvis_valid, ztag_valid = dataloader.valid(testsize)
    noise_valid = F.concat([zvis_valid, ztag_valid])

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = RGANLossFunction()

    # Evaluation
    evaluator = Evaluation()

    for epoch in range(epochs):
        # Accumulates the generator loss only.
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            # Discriminator updates
            for _ in range(n_dis):
                zvis, ztag, img, tag = dataloader.train(batchsize)

                y = generator(F.concat([zvis, ztag]))
                # Detach the fake so the discriminator loss does not
                # back-propagate into the generator.
                y.unchain_backward()

                loss = lossfunc.dis_loss(discriminator, y, img, tag, ztag)
                loss += lossfunc.gradient_penalty(discriminator, img, tag)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            # Generator update
            zvis, ztag, _, _ = dataloader.train(batchsize)

            y = generator(F.concat([zvis, ztag]))

            loss = lossfunc.gen_loss(discriminator, y, ztag)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            # Snapshot and evaluate once per epoch, right after its first
            # update step.
            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)

                # Evaluation is inference only: also disable graph
                # construction so its activations are not retained.
                with chainer.using_config('train', False), \
                        chainer.no_backprop_mode():
                    y_valid = generator(noise_valid)

                evaluator(y_valid.data.get(), outdir, epoch, testsize)

        # NOTE(review): the inner loop runs once per `batchsize` step of
        # `range(0, iterations, batchsize)`, yet the sum is divided by
        # `iterations` -- confirm this scaling is intended.
        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")