Example No. 1
def train(epochs, iterations, outdir, path, batchsize, validsize, model_type):
    # Dataset Definition
    dataloader = DatasetLoader(path)
    print(dataloader)
    t_valid, x_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    if model_type == 'ram':
        model = Model()
    elif model_type == 'gan':
        model = Generator()
    else:
        raise ValueError(f"model_type must be 'ram' or 'gan', got {model_type}")
    model.to_gpu()
    optimizer = set_optimizer(model)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = RAMLossFunction()
    print(lossfunc)

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            t_train, x_train = dataloader(batchsize, mode="train")

            y_train = model(x_train)
            y_feat = vgg(y_train)
            t_feat = vgg(t_train)
            loss = lossfunc.content_loss(y_train, t_train)
            loss += lossfunc.perceptual_loss(y_feat, t_feat)

            model.cleargrads()
            vgg.cleargrads()
            loss.backward()
            optimizer.update()
            vgg_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{outdir}/model_{epoch}.model", model)

                with chainer.using_config('train', False):
                    y_valid = model(x_valid)
                x = x_valid.data.get()
                y = y_valid.data.get()
                t = t_valid.data.get()

                evaluator(x, y, t, epoch, outdir)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 2
def train(epochs, iterations, batchsize, outdir, data_path):
    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    #generator = Generator()
    generator = GeneratorWithCIN()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, alpha=0.0002)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, alpha=0.0001)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)
            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            y_fake.unchain_backward()

            loss = lossfunc.dis_loss(discriminator, y_fake, x_sp, y_label,
                                     x_label)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            x_fake = generator(y_fake, F.concat([x_label, y_label]))
            x_identity = generator(x_sp, F.concat([x_label, x_label]))
            loss = lossfunc.gen_loss(discriminator, y_fake, x_fake, x_sp,
                                     F.concat([y_label, x_label]))
            if epoch < 50:
                loss += lossfunc.identity_loss(x_identity, x_sp)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"modeldirCIN/generator_{epoch}.model",
                                     generator)
                serializers.save_npz('discriminator.model', discriminator)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 3
def train(epochs, batchsize, iterations, nc_size, data_path, modeldir):
    # Dataset definition
    dataset = DatasetLoader(data_path, nc_size)

    # Model Definition & Optimizer Definition
    generator = Generator(nc_size)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, 0.0001, 0.5)
    
    discriminator = Discriminator(nc_size)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, 0.0001, 0.5)

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label = dataset.train(batchsize)

            y_fake = generator(x, y_label)
            y_fake.unchain_backward()

            loss = adversarial_loss_dis(discriminator, y_fake, x, y_label, x_label)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            y_fake = generator(x, y_label)
            x_fake = generator(y_fake, x_label)
            x_id = generator(x, x_label)

            loss = adversarial_loss_gen(discriminator, y_fake, x_fake, x, y_label)

            if epoch < 20:
                loss += 10 * F.mean_absolute_error(x_id, x)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz("discriminator.model", discriminator)

        print(f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}")
Example No. 4
def train(epochs, iterations, outdir, path, batchsize, validsize):
    # Dataset Definition
    dataloader = DatasetLoader(path)
    print(dataloader)
    t_valid, x_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    model = Generator()
    model.to_gpu()
    optimizer = set_optimizer(model)

    # Loss Function Definition
    lossfunc = ESRGANPretrainLossFunction()
    print(lossfunc)

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            t_train, x_train = dataloader(batchsize, mode="train")

            y_train = model(x_train)
            loss = lossfunc.content_loss(y_train, t_train)

            model.cleargrads()
            loss.backward()
            optimizer.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{outdir}/model_{epoch}.model", model)

                with chainer.using_config('train', False):
                    y_valid = model(x_valid)
                x = x_valid.data.get()
                y = y_valid.data.get()
                t = t_valid.data.get()

                evaluator(x, y, t, epoch, outdir)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 5
def train(epochs, iterations, batchsize, validsize, src_path, tgt_path,
          extension, img_size, outdir, modeldir, lr_dis, lr_gen, beta1, beta2):

    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataset)
    x_val, x_mask_val, y_val, y_mask_val = dataset.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, lr_gen, beta1, beta2)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, lr_gen, beta1, beta2)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, lr_dis, beta1, beta2)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, lr_dis, beta1, beta2)

    # Loss Function definition
    lossfunc = InstaGANLossFunction()

    # Visualizer definition
    visualize = Visualizer()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0

        for batch in range(0, iterations, batchsize):
            x, x_mask, y, y_mask = dataset.train(batchsize)

            # discriminator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            xy.unchain_backward()
            xy_mask.unchain_backward()
            yx.unchain_backward()
            yx_mask.unchain_backward()

            dis_loss = lossfunc.adversarial_dis_loss(discriminator_y, xy,
                                                     xy_mask, y, y_mask)
            dis_loss += lossfunc.adversarial_dis_loss(discriminator_x, yx,
                                                      yx_mask, x, x_mask)

            discriminator_y.cleargrads()
            discriminator_x.cleargrads()
            dis_loss.backward()
            dis_y_opt.update()
            dis_x_opt.update()

            sum_dis_loss += dis_loss.data

            # generator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            xyx, xyx_mask = generator_yx(xy, xy_mask)
            yxy, yxy_mask = generator_xy(yx, yx_mask)

            x_id, x_mask_id = generator_yx(x, x_mask)
            y_id, y_mask_id = generator_xy(y, y_mask)

            gen_loss = lossfunc.adversarial_gen_loss(discriminator_y, xy,
                                                     xy_mask)
            gen_loss += lossfunc.adversarial_gen_loss(discriminator_x, yx,
                                                      yx_mask)

            gen_loss += lossfunc.cycle_consistency_loss(
                xyx, xyx_mask, x, x_mask)
            gen_loss += lossfunc.cycle_consistency_loss(
                yxy, yxy_mask, y, y_mask)

            gen_loss += lossfunc.identity_mapping_loss(x_id, x_mask_id, x,
                                                       x_mask)
            gen_loss += lossfunc.identity_mapping_loss(y_id, y_mask_id, y,
                                                       y_mask)

            gen_loss += lossfunc.context_preserving_loss(
                xy, xy_mask, x, x_mask)
            gen_loss += lossfunc.context_preserving_loss(
                yx, yx_mask, y, y_mask)

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model",
                                     generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model",
                                     generator_yx)

                xy, xy_mask = generator_xy(x_val, x_mask_val)
                yx, yx_mask = generator_yx(y_val, y_mask_val)

                x = x_val.data.get()
                x_mask = x_mask_val.data.get()
                xy = xy.data.get()
                xy_mask = xy_mask.data.get()

                visualize(x,
                          x_mask,
                          xy,
                          xy_mask,
                          outdir,
                          epoch,
                          validsize,
                          switch="mtot")

                y = y_val.data.get()
                y_mask = y_mask_val.data.get()
                yx = yx.data.get()
                yx_mask = yx_mask.data.get()

                visualize(y,
                          y_mask,
                          yx,
                          yx_mask,
                          outdir,
                          epoch,
                          validsize,
                          switch="ttom")

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
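
A plausible form of the mask-aware cycle-consistency and identity-mapping terms used above, assuming weighted L1 penalties on both the image and its instance mask (the real InstaGANLossFunction may weight them differently):

import chainer.functions as F

def cycle_consistency_loss(recon, recon_mask, real, real_mask, weight=10.0):
    # L1 on the reconstructed image and on its mask
    return weight * (F.mean_absolute_error(recon, real)
                     + F.mean_absolute_error(recon_mask, real_mask))

def identity_mapping_loss(identity, identity_mask, real, real_mask, weight=10.0):
    # penalize changing inputs that are already in the target domain
    return weight * (F.mean_absolute_error(identity, real)
                     + F.mean_absolute_error(identity_mask, real_mask))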
Example No. 6
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          data_path, extension, img_size, latent_dim, learning_rate, beta1,
          beta2, enable):

    # Dataset Definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model Definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        enc_opt = set_optimizer(encoder)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = GauGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)

            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            sum_dis_loss += dis_loss.data

            # Generator update
            z = dataloader.noise_generator(batchsize)

            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)

            gen_loss = lossfunc.gen_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))
            gen_loss += lossfunc.content_loss(y, color)

            if enable:
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu,
                                                            sigma) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

                with chainer.using_config("train", False):
                    y = generator(noise_valid, line_valid)
                y = y.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()

                evaluator(y, cr, sr, outdir, epoch, validsize=validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
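
Note that Chainer's F.gaussian(mean, ln_var) and F.gaussian_kl_divergence(mean, ln_var) treat their second argument as a log-variance, so the encoder here presumably returns (mu, ln_var) despite the sigma name. A hypothetical skeleton of such an encoder:

import chainer
import chainer.functions as F
import chainer.links as L

class Encoder(chainer.Chain):
    def __init__(self, base=64, latent_dim=256):
        super().__init__()
        with self.init_scope():
            self.c0 = L.Convolution2D(3, base, 4, 2, 1)
            self.l_mu = L.Linear(None, latent_dim)
            self.l_ln_var = L.Linear(None, latent_dim)

    def __call__(self, x):
        # returns the mean and log-variance consumed by F.gaussian above
        h = F.leaky_relu(self.c0(x))
        return self.l_mu(h), self.l_ln_var(h)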
Example No. 7
def train(epochs,
          iterations,
          batchsize,
          validsize,
          outdir,
          modeldir,
          src_path,
          tgt_path,
          extension,
          img_size,
          learning_rate,
          beta1
          ):

    # Dataset definition
    dataloader = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataloader)
    src_val = dataloader.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, learning_rate, beta1)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, learning_rate, beta1)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, learning_rate, beta1)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, learning_rate, beta1)

    # LossFunction definition
    lossfunc = CycleGANLossCalculator()

    # Visualization
    visualizer = Visualization()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, y = dataloader.train(batchsize)

            # Discriminator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            xy.unchain_backward()
            yx.unchain_backward()

            dis_loss_xy = lossfunc.dis_loss(discriminator_y, xy, y)
            dis_loss_yx = lossfunc.dis_loss(discriminator_x, yx, x)

            dis_loss = dis_loss_xy + dis_loss_yx

            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            dis_loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            xyx = generator_yx(xy)
            yxy = generator_xy(yx)

            y_id = generator_xy(y)
            x_id = generator_yx(x)

            # adversarial loss
            gen_loss_xy = lossfunc.gen_loss(discriminator_y, xy)
            gen_loss_yx = lossfunc.gen_loss(discriminator_x, yx)

            # cycle-consistency loss
            cycle_y = lossfunc.cycle_consitency_loss(yxy, y)
            cycle_x = lossfunc.cycle_consitency_loss(xyx, x)

            # identity mapping loss
            identity_y = lossfunc.identity_mapping_loss(y_id, y)
            identity_x = lossfunc.identity_mapping_loss(x_id, x)

            gen_loss = gen_loss_xy + gen_loss_yx + cycle_x + cycle_y + identity_x + identity_y

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                with chainer.using_config('train', False):
                    tgt = generator_xy(src_val)

                src = src_val.data.get()
                tgt = tgt.data.get()

                visualizer(src, tgt, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(F"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}")
Example No. 8
def train(epochs,
          iterations,
          batchsize,
          validsize,
          data_path,
          sketch_path,
          digi_path,
          extension,
          img_size,
          outdir,
          modeldir,
          pretrained_epoch,
          adv_weight,
          enf_weight,
          sn,
          bn,
          activ):

    # Dataset Definition
    dataloader = DataLoader(data_path, sketch_path, digi_path,
                            extension=extension, img_size=img_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", bn=bn, activ=activ)
    #generator = SAGenerator(attn_type="sa", base=64)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator(sn=sn)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line, mask, mask_ds = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            if epoch > pretrained_epoch:
                adv_weight = 0.1
                enf_weight = 0.0

            # Discriminator update
            fake, _ = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)

            fake.unchain_backward()

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)

            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += enf_weight * lossfunc.positive_enforcing_loss(fake)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)
            loss += lossfunc.perceptual_loss(vgg, fake, color)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                extractor = vgg(line_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()
                line_valid_input = F.concat([line_valid, mask_valid])

                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)

                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                guide_valid = guide_valid.data.get()

                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 9
def train(epochs, iterations, batchsize, data_path, modeldir, extension,
          img_size, learning_rate, beta1, weight_decay):

    # Dataset definition
    dataset = DatasetLoader(data_path, extension, img_size)

    # Model & Optimizer definition
    generator = Generator(dataset.number)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, weight_decay)

    discriminator = Discriminator(dataset.number)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss Function definition
    lossfunc = RelGANLossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label, z, z_label = dataset.train(batchsize)

            # Discriminator update
            # Adversairal loss
            a = y_label - x_label
            fake = generator(x, a)
            fake.unchain_backward()
            loss = lossfunc.adversarial_loss_dis(discriminator, fake, y)

            # Interpolation loss
            rnd = np.random.randint(2)
            if rnd == 0:
                alpha = xp.random.uniform(0, 0.5, size=batchsize)
            else:
                alpha = xp.random.uniform(0.5, 1.0, size=batchsize)
            alpha = chainer.as_variable(alpha.astype(xp.float32))
            alpha = F.tile(F.expand_dims(alpha, axis=1), (1, dataset.number))

            fake_0 = generator(x, y_label - y_label)
            fake_1 = generator(x, alpha * a)
            fake_0.unchain_backward()
            fake_1.unchain_backward()
            loss += 10 * lossfunc.interpolation_loss_dis(
                discriminator, fake_0, fake, fake_1, alpha, rnd)

            # Matching loss
            v2 = y_label - z_label
            v3 = z_label - x_label

            loss += lossfunc.matching_loss_dis(discriminator, x, fake, y, z, a,
                                               v2, v3)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            # Generator update
            # Adversarial loss
            fake = generator(x, a)
            loss = lossfunc.adversarial_loss_gen(discriminator, fake)

            # Interpolation loss
            rnd = np.random.randint(2)
            if rnd == 0:
                alpha = xp.random.uniform(0, 0.5, size=batchsize)
            else:
                alpha = xp.random.uniform(0.5, 1.0, size=batchsize)
            alpha = chainer.as_variable(alpha.astype(xp.float32))
            alpha = F.tile(F.expand_dims(alpha, axis=1), (1, dataset.number))

            fake_alpha = generator(x, alpha * a)
            loss += 10 * lossfunc.interpolation_loss_gen(
                discriminator, fake_alpha)

            # Matching loss
            loss += lossfunc.matching_loss_gen(discriminator, x, fake, a)

            # Cycle-consistency loss
            cyc = generator(fake, -a)
            loss += 10 * F.mean_absolute_error(cyc, x)

            # Self-reconstruction loss
            fake_0 = generator(x, y_label - y_label)
            loss += 10 * F.mean_absolute_error(fake_0, x)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

        print(
            f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}"
        )
Example No. 10
def train(epochs, iterations, batchsize, src_path, tgt_path, modeldir):
    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path)
    print(dataset)

    # Model & Optimizer Definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx)

    discriminator_y = MSDiscriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y)

    discriminator_x = MSDiscriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x)

    # Loss Function Definition
    lossfunc = CycleGANVC2LossFunction()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, y = dataset.train(batchsize)

            xy = generator_xy(x)
            yx = generator_yx(y)

            xy.unchain_backward()
            yx.unchain_backward()

            loss = lossfunc.adv_dis_loss(discriminator_y, xy, y)
            loss += lossfunc.adv_dis_loss(discriminator_x, yx, x)

            sum_dis_loss += loss.data

            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()
            loss.unchain_backward()

            xy = generator_xy(x)
            xyx = generator_yx(xy)
            id_y = generator_xy(y)

            yx = generator_yx(y)
            yxy = generator_xy(yx)
            id_x = generator_yx(x)

            loss = lossfunc.adv_gen_loss(discriminator_y, xy)
            loss += lossfunc.adv_gen_loss(discriminator_x, yx)

            cycle_loss_x = lossfunc.recon_loss(xyx, x)
            cycle_loss_y = lossfunc.recon_loss(yxy, y)
            cycle_loss = cycle_loss_x + cycle_loss_y

            identity_loss_y = lossfunc.recon_loss(id_y, y)
            identity_loss_x = lossfunc.recon_loss(id_x, x)
            identity_loss = identity_loss_x + identity_loss_y

            if epoch > 20:
                identity_weight = 0.0
            else:
                identity_weight = 5.0

            loss += 10 * cycle_loss + identity_weight * identity_loss

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data.get()

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx.model", generator_yx)

        print('epoch : {}'.format(epoch))
        print('Generator loss : {}'.format(sum_gen_loss / iterations))
        print('Discriminator loss : {}'.format(sum_dis_loss / iterations))
Example No. 11
def train_refine(epochs,
                 iterations,
                 batchsize,
                 validsize,
                 data_path,
                 sketch_path,
                 digi_path,
                 st_path,
                 extension,
                 img_size,
                 crop_size,
                 outdir,
                 modeldir,
                 adv_weight,
                 enf_weight):

    # Dataset Definition
    dataloader = RefineDataset(data_path, sketch_path, digi_path, st_path,
                               extension=extension, img_size=img_size, crop_size=crop_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid, cm_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", base=64, bn=True, activ=F.relu)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    iteration = 0

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            iteration += 1
            color, line, mask, mask_ds, color_mask = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(color_mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            # Discriminator update
            fake, _ = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)

            fake.unchain_backward()

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)

            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                extractor = vgg(cm_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()
                line_valid_input = F.concat([line_valid, mask_valid])

                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)

                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                cm_val = cm_valid.data.get()
                guide_valid = guide_valid.data.get()
                input_valid = np.concatenate([input_valid[:, 3:6], cm_val], axis=1)

                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

            print(f"iter: {iteration} dis loss: {sum_dis_loss} gen loss: {gen_loss}")
Example No. 12
def train(epochs,
          iterations,
          dataset_path,
          test_path,
          outdir,
          batchsize,
          testsize,
          recon_weight,
          fm_weight,
          gp_weight,
          spectral_norm=False):
    # Dataset Definition
    dataloader = DatasetLoader(dataset_path, test_path)
    c_valid, s_valid = dataloader.test(testsize)

    # Model & Optimizer Definition
    if spectral_norm:
        generator = SNGenerator()
    else:
        generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = FUNITLossFunction()

    # Evaluator Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            c, ci, s, si = dataloader.train(batchsize)

            y = generator(c, s)
            y.unchain_backward()

            loss = lossfunc.dis_loss(discriminator, y, s, si)
            loss += lossfunc.gradient_penalty(discriminator, s, y, si)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_convert = generator(c, s)
            y_recon = generator(c, c)

            adv_loss, recon_loss, fm_loss = lossfunc.gen_loss(
                discriminator, y_convert, y_recon, s, c, si, ci)
            loss = adv_loss + recon_weight * recon_loss + fm_weight * fm_loss

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz('generator.model', generator)
                serializers.save_npz('discriminator.model', discriminator)

                with chainer.using_config('train', False):
                    y = generator(c_valid, s_valid)
                y.unchain_backward()

                y = y.data.get()
                c = c_valid.data.get()
                s = s_valid.data.get()

                evaluator(y, c, s, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 13
def train(epochs, iterations, batchsize, testsize, img_path, seg_path, outdir,
          modeldir, n_dis, mode):
    # Dataset Definition
    dataloader = DatasetLoader(img_path, seg_path)
    print(dataloader)
    valid_noise = dataloader.test(testsize)

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = SGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                t, s, noise = dataloader.train(batchsize)
                y_img, y_seg = generator(noise)

                loss = lossfunc.dis_loss(discriminator, y_img, y_seg, t, s)
                loss += lossfunc.gradient_penalty(discriminator,
                                                  y_img,
                                                  y_seg,
                                                  t,
                                                  s,
                                                  mode=mode)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            _, _, noise = dataloader.train(batchsize)
            y_img, y_seg = generator(noise)

            loss = lossfunc.gen_loss(discriminator, y_img, y_seg)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model",
                                     discriminator)

                with chainer.using_config('train', False):
                    y_img, y_seg = generator(valid_noise)
                y_img = y_img.data.get()
                y_seg = y_seg.data.get()

                evaluator(y_img, y_seg, epoch, outdir, testsize=testsize)

        print(f"epoh: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 14
    input_dim=args.input_dim,
    n_heads=args.n_heads,
    n_blocks=args.n_blocks,
    dropout=args.dropout,
    ff_hidden_dim=4 * args.input_dim if not args.ff_hidden_dim else args.ff_hidden_dim,
    transformer_activation=args.transformer_activation,
    mlp_hidden_dims=mlp_hidden_dims,
    mlp_activation=args.mlp_activation,
    mlp_batchnorm=args.mlp_batchnorm,
    mlp_batchnorm_last=args.mlp_batchnorm_last,
    mlp_linear_first=args.mlp_linear_first,
)

model = WideDeep(wide=wide, deeptabular=deeptabular)

optimizers = set_optimizer(model, args)

steps_per_epoch = (X_tab_train.shape[0] // args.batch_size) + 1
lr_schedulers = set_lr_scheduler(optimizers, steps_per_epoch, args)

early_stopping = EarlyStopping(
    monitor=args.monitor,
    min_delta=args.early_stop_delta,
    patience=args.early_stop_patience,
)

trainer = Trainer(
    model,
    objective="binary",
    optimizers=optimizers,
    lr_schedulers=lr_schedulers,
Example No. 15
def train(epochs, iterations, batchsize, modeldir, extension, time_width,
          mel_bins, sampling_rate, g_learning_rate, d_learning_rate, beta1,
          beta2, identity_epoch, adv_type, residual_flag, data_path):

    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    generator = GeneratorWithCIN(adv_type=adv_type)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, g_learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, d_learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
            else:
                raise AttributeError

            y_fake.unchain_backward()

            if adv_type == 'sat':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, F.concat([y_label, x_label]),
                    F.concat([x_label, y_label]), residual_flag)
            elif adv_type == 'orig':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, y_label, x_label,
                    residual_flag)
            else:
                raise AttributeError

            dis_loss = advloss_dis_real + advloss_dis_fake
            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
                x_fake = generator(y_fake, F.concat([x_label, y_label]))
                x_identity = generator(x_sp, F.concat([x_label, x_label]))
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp,
                    F.concat([y_label, x_label]), residual_flag)
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
                x_fake = generator(y_fake, x_label)
                x_identity = generator(x_sp, x_label)
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp, y_label,
                    residual_flag)
            else:
                raise AttributeError

            if epoch < identity_epoch:
                identity_loss = lossfunc.identity_loss(x_identity, x_sp)
            else:
                identity_loss = call_zeros(advloss_dis_fake)

            gen_loss = advloss_gen_fake + cycle_loss + identity_loss
            generator.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            gen_loss.unchain_backward()

            sum_dis_loss += dis_loss.data
            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
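
call_zeros above is presumably a small helper that returns a zero Variable shaped like its argument, used to switch the identity term off once epoch reaches identity_epoch; a sketch of that assumption:

import chainer

def call_zeros(variable):
    # zero-valued Variable with the same shape/dtype, on the same device
    xp = chainer.backend.get_array_module(variable.array)
    return chainer.as_variable(xp.zeros_like(variable.array))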
Example No. 16
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          extension, train_size, valid_size, data_path, sketch_path, digi_path,
          learning_rate, beta1, weight_decay):

    # Dataset definition
    dataset = DatasetLoader(data_path, sketch_path, digi_path, extension,
                            train_size, valid_size)
    print(dataset)
    x_val, t_val = dataset.valid(validsize)

    # Model & Optimizer definition
    unet = UNet()
    unet.to_gpu()
    unet_opt = set_optimizer(unet, learning_rate, beta1, weight_decay)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss function definition
    lossfunc = Pix2pixLossCalculator()

    # Visualization definition
    visualizer = Visualizer()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, t = dataset.train(batchsize)

            # Discriminator update
            y = unet(x)
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, y, t)

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            y = unet(x)

            gen_loss = lossfunc.gen_loss(discriminator, y)
            gen_loss += lossfunc.content_loss(y, t)

            unet.cleargrads()
            gen_loss.backward()
            unet_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/unet_{epoch}.model", unet)

                with chainer.using_config("train", False):
                    y = unet(x_val)

                x = x_val.data.get()
                t = t_val.data.get()
                y = y.data.get()

                visualizer(x, t, y, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}"
        )
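
A hedged sketch of the pix2pix-style losses these calls assume: softplus adversarial terms plus a weighted L1 content term (the actual Pix2pixLossCalculator may also condition the discriminator on the input sketch):

import chainer.functions as F

class Pix2pixLossCalculator:
    def dis_loss(self, discriminator, fake, real):
        # non-saturating GAN loss for the discriminator
        return F.mean(F.softplus(-discriminator(real))) \
            + F.mean(F.softplus(discriminator(fake)))

    def gen_loss(self, discriminator, fake):
        return F.mean(F.softplus(-discriminator(fake)))

    def content_loss(self, y, t, weight=10.0):
        # weighted L1 between the translated image and the target
        return weight * F.mean_absolute_error(y, t)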
Example No. 17
def train(epochs, iterations, outdir, path, batchsize, validsize,
          adv_weight, content_weight):
    # Dataset Definition
    dataloader = DatasetLoader(path)
    print(dataloader)
    t_valid, x_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    model = Generator()
    model.to_gpu()
    optimizer = set_optimizer(model)
    serializers.load_npz('./outdir_pretrain/model_80.model', model)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = ESRGANLossFunction()
    print(lossfunc)

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            t_train, x_train = dataloader(batchsize, mode="train")

            y_train = model(x_train)
            y_train.unchain_backward()
            loss = adv_weight * lossfunc.dis_hinge_loss(discriminator, y_train, t_train)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_train = model(x_train)
            loss = adv_weight * lossfunc.gen_hinge_loss(discriminator, y_train)
            loss += content_weight * lossfunc.content_loss(y_train, t_train)
            loss += lossfunc.perceptual_loss(vgg, y_train, t_train)

            model.cleargrads()
            vgg.cleargrads()
            loss.backward()
            optimizer.update()
            vgg_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{outdir}/model_{epoch}.model", model)

                with chainer.using_config('train', False):
                    y_valid = model(x_valid)
                x = x_valid.data.get()
                y = y_valid.data.get()
                t = t_valid.data.get()

                evaluator(x, y, t, epoch, outdir)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example No. 18
def train(epochs, iterations, batchsize, testsize, outdir, modeldir, n_dis,
          img_path, tag_path):
    # Dataset Definition
    dataloader = DatasetLoader(img_path, tag_path)
    zvis_valid, ztag_valid = dataloader.valid(testsize)
    noise_valid = F.concat([zvis_valid, ztag_valid])

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = RGANLossFunction()

    # Evaluation
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                zvis, ztag, img, tag = dataloader.train(batchsize)

                y = generator(F.concat([zvis, ztag]))
                y.unchain_backward()

                loss = lossfunc.dis_loss(discriminator, y, img, tag, ztag)
                loss += lossfunc.gradient_penalty(discriminator, img, tag)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            zvis, ztag, _, _ = dataloader.train(batchsize)

            y = generator(F.concat([zvis, ztag]))

            loss = lossfunc.gen_loss(discriminator, y, ztag)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)

                with chainer.using_config('train', False):
                    y = generator(noise_valid)
                y = y.data.get()

                evaluator(y, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")