Example #1
def train(epochs, iterations, batchsize, outdir, data_path):
    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    #generator = Generator()
    generator = GeneratorWithCIN()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, alpha=0.0002)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, alpha=0.0001)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)
            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            y_fake.unchain_backward()

            loss = lossfunc.dis_loss(discriminator, y_fake, x_sp, y_label,
                                     x_label)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            x_fake = generator(y_fake, F.concat([x_label, y_label]))
            x_identity = generator(x_sp, F.concat([x_label, x_label]))
            loss = lossfunc.gen_loss(discriminator, y_fake, x_fake, x_sp,
                                     F.concat([y_label, x_label]))
            if epoch < 50:
                loss += lossfunc.identity_loss(x_identity, x_sp)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"modeldirCIN/generator_{epoch}.model",
                                     generator)
                serializers.save_npz('discriminator.model', discriminator)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, batchsize, iterations, nc_size, data_path, modeldir):
    # Dataset definition
    dataset = DatasetLoader(data_path, nc_size)

    # Model Definition & Optimizer Definition
    generator = Generator(nc_size)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, 0.0001, 0.5)
    
    discriminator = Discriminator(nc_size)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, 0.0001, 0.5)

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label = dataset.train(batchsize)

            y_fake = generator(x, y_label)
            y_fake.unchain_backward()

            loss = adversarial_loss_dis(discriminator, y_fake, x, y_label, x_label)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            y_fake = generator(x, y_label)
            x_fake = generator(y_fake, x_label)
            x_id = generator(x, x_label)

            loss = adversarial_loss_gen(discriminator, y_fake, x_fake, x, y_label)

            if epoch < 20:
                loss += 10 * F.mean_absolute_error(x_id, x)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz("discriminator.model", discriminator)

        print(f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}")
Example #3
            seg.unchain_backward()

            std_data = xp.std(t.data, axis=0, keepdims=True)
            rnd_x = xp.random.uniform(0, 1, t.shape).astype(xp.float32)
            x_perturbed = rnd_x * t + (1 - rnd_x) * x
            s_perturbed = rnd_x * s + (1 - rnd_x) * seg

            y_perturbed = discriminator(x_perturbed, s_perturbed)
            grad, = chainer.grad([y_perturbed], [x_perturbed],
                                 enable_double_backprop=True)
            grad = F.sqrt(F.batch_l2_norm_squared(grad))
            loss_grad = lambda1 * F.mean_squared_error(grad,
                                                       xp.ones_like(grad.data))
            dis_loss += loss_grad

            discriminator.cleargrads()
            dis_loss.backward()
            dis_loss.unchain_backward()
            dis_opt.update()

        z = chainer.as_variable(
            xp.random.uniform(-1, 1, (batchsize, 256)).astype(xp.float32))
        x, seg = generator(z)
        fake = discriminator(x, seg)
        gen_loss = loss_hinge_gen(fake)

        generator.cleargrads()
        gen_loss.backward()
        gen_loss.unchain_backward()
        gen_opt.update()
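The excerpt above computes its generator loss with loss_hinge_gen, which is not defined in the excerpt. The standard hinge formulation it presumably refers to is sketched below (the discriminator counterpart is included for context).

import chainer.functions as F


def loss_hinge_dis(d_fake, d_real):
    # hinge loss for the discriminator: push real outputs above +1 and fake outputs below -1
    return F.mean(F.relu(1.0 - d_real)) + F.mean(F.relu(1.0 + d_fake))


def loss_hinge_gen(d_fake):
    # hinge loss for the generator: raise the discriminator's score on generated samples
    return -F.mean(d_fake)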
Example #4
def train(epochs, iterations, batchsize, validsize, src_path, tgt_path,
          extension, img_size, outdir, modeldir, lr_dis, lr_gen, beta1, beta2):

    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataset)
    x_val, x_mask_val, y_val, y_mask_val = dataset.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, lr_gen, beta1, beta2)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, lr_gen, beta1, beta2)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, lr_dis, beta1, beta2)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, lr_dis, beta1, beta2)

    # Loss Function definition
    lossfunc = InstaGANLossFunction()

    # Visualizer definition
    visualize = Visualizer()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0

        for batch in range(0, iterations, batchsize):
            x, x_mask, y, y_mask = dataset.train(batchsize)

            # discriminator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            xy.unchain_backward()
            xy_mask.unchain_backward()
            yx.unchain_backward()
            yx_mask.unchain_backward()

            dis_loss = lossfunc.adversarial_dis_loss(discriminator_y, xy,
                                                     xy_mask, y, y_mask)
            dis_loss += lossfunc.adversarial_dis_loss(discriminator_x, yx,
                                                      yx_mask, x, x_mask)

            discriminator_y.cleargrads()
            discriminator_x.cleargrads()
            dis_loss.backward()
            dis_y_opt.update()
            dis_x_opt.update()

            sum_dis_loss += dis_loss.data

            # generator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            xyx, xyx_mask = generator_yx(xy, xy_mask)
            yxy, yxy_mask = generator_xy(yx, yx_mask)

            x_id, x_mask_id = generator_yx(x, x_mask)
            y_id, y_mask_id = generator_xy(y, y_mask)

            gen_loss = lossfunc.adversarial_gen_loss(discriminator_y, xy,
                                                     xy_mask)
            gen_loss += lossfunc.adversarial_gen_loss(discriminator_x, yx,
                                                      yx_mask)

            gen_loss += lossfunc.cycle_consistency_loss(
                xyx, xyx_mask, x, x_mask)
            gen_loss += lossfunc.cycle_consistency_loss(
                yxy, yxy_mask, y, y_mask)

            gen_loss += lossfunc.identity_mapping_loss(x_id, x_mask_id, x,
                                                       x_mask)
            gen_loss += lossfunc.identity_mapping_loss(y_id, y_mask_id, y,
                                                       y_mask)

            gen_loss += lossfunc.context_preserving_loss(
                xy, xy_mask, x, x_mask)
            gen_loss += lossfunc.context_preserving_loss(
                yx, yx_mask, y, y_mask)

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model",
                                     generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model",
                                     generator_yx)

                xy, xy_mask = generator_xy(x_val, x_mask_val)
                yx, yx_mask = generator_yx(y_val, y_mask_val)

                x = x_val.data.get()
                x_mask = x_mask_val.data.get()
                xy = xy.data.get()
                xy_mask = xy_mask.data.get()

                visualize(x,
                          x_mask,
                          xy,
                          xy_mask,
                          outdir,
                          epoch,
                          validsize,
                          switch="mtot")

                y = y_val.data.get()
                y_mask = y_mask_val.data.get()
                yx = yx.data.get()
                yx_mask = yx_mask.data.get()

                visualize(y,
                          y_mask,
                          yx,
                          yx_mask,
                          outdir,
                          epoch,
                          validsize,
                          switch="ttom")

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
Example #5
def train(epochs,
          iterations,
          batchsize,
          validsize,
          outdir,
          modeldir,
          src_path,
          tgt_path,
          extension,
          img_size,
          learning_rate,
          beta1
          ):

    # Dataset definition
    dataloader = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataloader)
    src_val = dataloader.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, learning_rate, beta1)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, learning_rate, beta1)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, learning_rate, beta1)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, learning_rate, beta1)

    # LossFunction definition
    lossfunc = CycleGANLossCalculator()

    # Visualization
    visualizer = Visualization()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, y = dataloader.train(batchsize)

            # Discriminator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            xy.unchain_backward()
            yx.unchain_backward()

            dis_loss_xy = lossfunc.dis_loss(discriminator_y, xy, y)
            dis_loss_yx = lossfunc.dis_loss(discriminator_x, yx, x)

            dis_loss = dis_loss_xy + dis_loss_yx

            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            dis_loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            xyx = generator_yx(xy)
            yxy = generator_xy(yx)

            y_id = generator_xy(y)
            x_id = generator_yx(x)

            # adversarial loss
            gen_loss_xy = lossfunc.gen_loss(discriminator_y, xy)
            gen_loss_yx = lossfunc.gen_loss(discriminator_x, yx)

            # cycle-consistency loss
            cycle_y = lossfunc.cycle_consitency_loss(yxy, y)
            cycle_x = lossfunc.cycle_consitency_loss(xyx, x)

            # identity mapping loss
            identity_y = lossfunc.identity_mapping_loss(y_id, y)
            identity_x = lossfunc.identity_mapping_loss(x_id, x)

            gen_loss = gen_loss_xy + gen_loss_yx + cycle_x + cycle_y + identity_x + identity_y

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                with chainer.using_config('train', False):
                    tgt = generator_xy(src_val)

                src = src_val.data.get()
                tgt = tgt.data.get()

                visualizer(src, tgt, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(F"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}")
Example #6
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          data_path, extension, img_size, latent_dim, learning_rate, beta1,
          beta2, enable):

    # Dataset Definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model Definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        enc_opt = set_optimizer(encoder)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = GauGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluaton()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)

            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            sum_dis_loss += dis_loss.data

            # Generator update
            z = dataloader.noise_generator(batchsize)

            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)

            gen_loss = lossfunc.gen_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))
            gen_loss += lossfunc.content_loss(y, color)

            if enable:
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu,
                                                            sigma) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

                with chainer.using_config("train", False):
                    y = generator(noise_valid, line_valid)
                y = y.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()

                evaluator(y, cr, sr, outdir, epoch, validsize=validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
Example #7
def train(epochs, iterations, batchsize, data_path, modeldir, extension,
          img_size, learning_rate, beta1, weight_decay):

    # Dataset definition
    dataset = DatasetLoader(data_path, extension, img_size)

    # Model & Optimizer definition
    generator = Generator(dataset.number)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, weight_decay)

    discriminator = Discriminator(dataset.number)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss Function definition
    lossfunc = RelGANLossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label, z, z_label = dataset.train(batchsize)

            # Discriminator update
            # Adversarial loss
            a = y_label - x_label
            fake = generator(x, a)
            fake.unchain_backward()
            loss = lossfunc.adversarial_loss_dis(discriminator, fake, y)

            # Interpolation loss
            rnd = np.random.randint(2)
            if rnd == 0:
                alpha = xp.random.uniform(0, 0.5, size=batchsize)
            else:
                alpha = xp.random.uniform(0.5, 1.0, size=batchsize)
            alpha = chainer.as_variable(alpha.astype(xp.float32))
            alpha = F.tile(F.expand_dims(alpha, axis=1), (1, dataset.number))

            fake_0 = generator(x, y_label - y_label)
            fake_1 = generator(x, alpha * a)
            fake_0.unchain_backward()
            fake_1.unchain_backward()
            loss += 10 * lossfunc.interpolation_loss_dis(
                discriminator, fake_0, fake, fake_1, alpha, rnd)

            # Matching loss
            v2 = y_label - z_label
            v3 = z_label - x_label

            loss += lossfunc.matching_loss_dis(discriminator, x, fake, y, z, a,
                                               v2, v3)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            # Generator update
            # Adversarial loss
            fake = generator(x, a)
            loss = lossfunc.adversarial_loss_gen(discriminator, fake)

            # Interpolation loss
            rnd = np.random.randint(2)
            if rnd == 0:
                alpha = xp.random.uniform(0, 0.5, size=batchsize)
            else:
                alpha = xp.random.uniform(0.5, 1.0, size=batchsize)
            alpha = chainer.as_variable(alpha.astype(xp.float32))
            alpha = F.tile(F.expand_dims(alpha, axis=1), (1, dataset.number))

            fake_alpha = generator(x, alpha * a)
            loss += 10 * lossfunc.interpolation_loss_gen(
                discriminator, fake_alpha)

            # Matching loss
            loss += lossfunc.matching_loss_gen(discriminator, x, fake, a)

            # Cycle-consistency loss
            cyc = generator(fake, -a)
            loss += 10 * F.mean_absolute_error(cyc, x)

            # Self-reconstruction loss
            fake_0 = generator(x, y_label - y_label)
            loss += 10 * F.mean_absolute_error(fake_0, x)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

        print(
            f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}"
        )
def train(epochs, iterations, batchsize, testsize, img_path, seg_path, outdir,
          modeldir, n_dis, mode):
    # Dataset Definition
    dataloader = DatasetLoader(img_path, seg_path)
    print(dataloader)
    valid_noise = dataloader.test(testsize)

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = SGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                t, s, noise = dataloader.train(batchsize)
                y_img, y_seg = generator(noise)

                loss = lossfunc.dis_loss(discriminator, y_img, y_seg, t, s)
                loss += lossfunc.gradient_penalty(discriminator,
                                                  y_img,
                                                  y_seg,
                                                  t,
                                                  s,
                                                  mode=mode)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            _, _, noise = dataloader.train(batchsize)
            y_img, y_seg = generator(noise)

            loss = lossfunc.gen_loss(discriminator, y_img, y_seg)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model",
                                     discriminator)

                with chainer.using_config('train', False):
                    y_img, y_seg = generator(valid_noise)
                y_img = y_img.data.get()
                y_seg = y_seg.data.get()

                evaluator(y_img, y_seg, epoch, outdir, testsize=testsize)

        print(f"epoh: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example #9
        x_y.unchain_backward()
        y_x.unchain_backward()

        dis_fake_xy = discriminator_xy(x_y)
        dis_real_xy = discriminator_xy(t)
        #dis_loss_xy=F.mean(F.softplus(dis_fake_xy))+F.mean(F.softplus(-dis_real_xy))
        dis_loss_xy = least_square_loss(dis_fake_xy, dis_real_xy)

        dis_fake_yx = discriminator_yx(y_x)
        dis_real_yx = discriminator_yx(x)
        #dis_loss_yx=F.mean(F.softplus(dis_fake_yx))+F.mean(F.softplus(-dis_real_yx))
        dis_loss_yx = least_square_loss(dis_fake_yx, dis_real_yx)

        dis_loss = dis_loss_xy + dis_loss_yx

        discriminator_xy.cleargrads()
        discriminator_yx.cleargrads()
        dis_loss.backward()
        dis_opt_xy.update()
        dis_opt_yx.update()
        dis_loss.unchain_backward()

        x_y = generator_xy(x)
        x_y_x = generator_yx(x_y)

        y_x = generator_yx(t)
        y_x_y = generator_xy(y_x)

        dis_fake_xy = discriminator_xy(x_y)
        dis_fake_yx = discriminator_yx(y_x)
def train(epochs,
          iterations,
          dataset_path,
          test_path,
          outdir,
          batchsize,
          testsize,
          recon_weight,
          fm_weight,
          gp_weight,
          spectral_norm=False):
    # Dataset Definition
    dataloader = DatasetLoader(dataset_path, test_path)
    c_valid, s_valid = dataloader.test(testsize)

    # Model & Optimizer Definition
    if spectral_norm:
        generator = SNGenerator()
    else:
        generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = FUNITLossFunction()

    # Evaluator Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            c, ci, s, si = dataloader.train(batchsize)

            y = generator(c, s)
            y.unchain_backward()

            loss = lossfunc.dis_loss(discriminator, y, s, si)
            loss += lossfunc.gradient_penalty(discriminator, s, y, si)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_convert = generator(c, s)
            y_recon = generator(c, c)

            adv_loss, recon_loss, fm_loss = lossfunc.gen_loss(
                discriminator, y_convert, y_recon, s, c, si, ci)
            loss = adv_loss + recon_weight * recon_loss + fm_weight * fm_loss

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz('generator.model', generator)
                serializers.save_npz('discriminator.model', discriminator)

                with chainer.using_config('train', False):
                    y = generator(c_valid, s_valid)
                y.unchain_backward()

                y = y.data.get()
                c = c_valid.data.get()
                s = s_valid.data.get()

                evaluator(y, c, s, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, iterations, batchsize, modeldir, extension, time_width,
          mel_bins, sampling_rate, g_learning_rate, d_learning_rate, beta1,
          beta2, identity_epoch, adv_type, residual_flag, data_path):

    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    generator = GeneratorWithCIN(adv_type=adv_type)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, g_learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, d_learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
            else:
                raise AttributeError

            y_fake.unchain_backward()

            if adv_type == 'sat':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, F.concat([y_label, x_label]),
                    F.concat([x_label, y_label]), residual_flag)
            elif adv_type == 'orig':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, y_label, x_label,
                    residual_flag)
            else:
                raise AttributeError

            dis_loss = advloss_dis_real + advloss_dis_fake
            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
                x_fake = generator(y_fake, F.concat([x_label, y_label]))
                x_identity = generator(x_sp, F.concat([x_label, x_label]))
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp,
                    F.concat([y_label, x_label]), residual_flag)
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
                x_fake = generator(y_fake, x_label)
                x_identity = generator(x_sp, x_label)
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp, y_label,
                    residual_flag)
            else:
                raise AttributeError

            if epoch < identity_epoch:
                identity_loss = lossfunc.identity_loss(x_identity, x_sp)
            else:
                identity_loss = call_zeros(advloss_dis_fake)

            gen_loss = advloss_gen_fake + cycle_loss + identity_loss
            generator.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            gen_loss.unchain_backward()

            sum_dis_loss += dis_loss.data
            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
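call_zeros above supplies a zero-valued placeholder once the identity term is switched off. A plausible helper, assuming it simply mirrors the shape, dtype and device of its argument, is:

import chainer


def call_zeros(variable):
    # zero Variable matching the shape, dtype and device of `variable` (assumed behavior)
    xp = chainer.backends.cuda.get_array_module(variable.data)
    return chainer.as_variable(xp.zeros_like(variable.data))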
            rnd_x = xp.random.uniform(0, 1, x_dis.shape).astype(xp.float32)
            x_perturbed = Variable(cuda.to_gpu(x_dis + 0.5 * rnd_x * std_data))

            x_dis = Variable(cuda.to_gpu(x_dis))
            y_dis = dis_model(x_dis, Variable(t_dis))
            dis_loss += F.mean(F.softplus(-y_dis))

            y_perturbed = dis_model(x_perturbed, Variable(t_dis))
            grad, = chainer.grad([y_perturbed], [x_perturbed],
                                 enable_double_backprop=True)
            grad = F.sqrt(F.batch_l2_norm_squared(grad))
            loss_grad = lambda1 * F.mean_squared_error(grad,
                                                       xp.ones_like(grad.data))
            dis_loss += loss_grad

            dis_model.cleargrads()
            dis_loss.backward()
            dis_loss.unchain_backward()
            dis_opt.update()

        z = Variable(xp.random.normal(size=(batchsize, 128), dtype=xp.float32))
        label = cuda.to_gpu(get_fake_tag_batch(batchsize, dims, threshold))
        z = F.concat([z, Variable(label)])
        x = gen_model(z)
        y = dis_model(x, Variable(label))
        gen_loss = F.mean(F.softplus(-y))

        gen_model.cleargrads()
        gen_loss.backward()
        gen_loss.unchain_backward()
        gen_opt.update()
Example #13
def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('out')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU device ID')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=200,
                        help='# of epoch')
    parser.add_argument('--batch_size', '-b', type=int, default=10)
    parser.add_argument('--memory_size', '-m', type=int, default=500)
    parser.add_argument('--real_label', type=float, default=0.9)
    parser.add_argument('--fake_label', type=float, default=0.0)
    parser.add_argument('--block_num', type=int, default=6)
    parser.add_argument('--g_nobn',
                        dest='g_bn',
                        action='store_false',
                        default=True)
    parser.add_argument('--d_nobn',
                        dest='d_bn',
                        action='store_false',
                        default=True)
    parser.add_argument('--variable_size', action='store_true', default=False)
    parser.add_argument('--lambda_dis_real', type=float, default=0)
    parser.add_argument('--size', type=int, default=128)
    parser.add_argument('--lambda_', type=float, default=10)

    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()

    # log directory
    out = datetime.datetime.now().strftime('%m%d%H')
    out = out + '_' + args.out
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", out))
    os.makedirs(os.path.join(out_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'visualize'), exist_ok=True)

    # hyper parameter
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as f:
        for k, v in args._get_kwargs():
            print('{} = {}'.format(k, v))
            f.write('{} = {}\n'.format(k, v))

    trainA = ImageDataset('horse2zebra/trainA',
                          augmentation=True,
                          image_size=256,
                          final_size=args.size)
    trainB = ImageDataset('horse2zebra/trainB',
                          augmentation=True,
                          image_size=256,
                          final_size=args.size)
    testA = ImageDataset('horse2zebra/testA',
                         image_size=256,
                         final_size=args.size)
    testB = ImageDataset('horse2zebra/testB',
                         image_size=256,
                         final_size=args.size)

    train_iterA = chainer.iterators.MultiprocessIterator(trainA,
                                                         args.batch_size,
                                                         n_processes=min(
                                                             8,
                                                             args.batch_size))
    train_iterB = chainer.iterators.MultiprocessIterator(trainB,
                                                         args.batch_size,
                                                         n_processes=min(
                                                             8,
                                                             args.batch_size))
    N = len(trainA)

    # genA convert B -> A, genB convert A -> B
    genA = Generator(block_num=args.block_num, bn=args.g_bn)
    genB = Generator(block_num=args.block_num, bn=args.g_bn)
    # disA discriminate realA and fakeA, disB discriminate realB and fakeB
    disA = Discriminator(bn=args.d_bn)
    disB = Discriminator(bn=args.d_bn)

    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        genA.to_gpu()
        genB.to_gpu()
        disA.to_gpu()
        disB.to_gpu()

    optimizer_genA = chainer.optimizers.Adam(alpha=0.0002,
                                             beta1=0.5,
                                             beta2=0.9)
    optimizer_genB = chainer.optimizers.Adam(alpha=0.0002,
                                             beta1=0.5,
                                             beta2=0.9)
    optimizer_disA = chainer.optimizers.Adam(alpha=0.0002,
                                             beta1=0.5,
                                             beta2=0.9)
    optimizer_disB = chainer.optimizers.Adam(alpha=0.0002,
                                             beta1=0.5,
                                             beta2=0.9)

    optimizer_genA.setup(genA)
    optimizer_genB.setup(genB)
    optimizer_disA.setup(disA)
    optimizer_disB.setup(disB)

    # start training
    start = time.time()
    fake_poolA = np.zeros(
        (args.memory_size, 3, args.size, args.size)).astype('float32')
    fake_poolB = np.zeros(
        (args.memory_size, 3, args.size, args.size)).astype('float32')
    lambda_ = args.lambda_
    const_realA = np.asarray([testA.get_example(i) for i in range(10)])
    const_realB = np.asarray([testB.get_example(i) for i in range(10)])

    iterations = 0
    for epoch in range(args.epoch):

        if epoch > 100:
            decay_rate = 0.0002 / 100
            optimizer_genA.alpha -= decay_rate
            optimizer_genB.alpha -= decay_rate
            optimizer_disA.alpha -= decay_rate
            optimizer_disB.alpha -= decay_rate

        # train
        iter_num = N // args.batch_size
        for i in range(iter_num):

            # load real batch
            imagesA = train_iterA.next()
            imagesB = train_iterB.next()
            if args.variable_size:
                crop_size = np.random.choice([160, 192, 224, 256])
                resize_size = np.random.choice([160, 192, 224, 256])
                imagesA = [
                    random_augmentation(image, crop_size, resize_size)
                    for image in imagesA
                ]
                imagesB = [
                    random_augmentation(image, crop_size, resize_size)
                    for image in imagesB
                ]
            realA = chainer.Variable(genA.xp.asarray(imagesA, 'float32'))
            realB = chainer.Variable(genB.xp.asarray(imagesB, 'float32'))

            # load fake batch
            if iterations < args.memory_size:
                fakeA = genA(realB)
                fakeB = genB(realA)
                fakeA.unchain_backward()
                fakeB.unchain_backward()
            else:
                fake_imagesA = fake_poolA[np.random.randint(
                    args.memory_size, size=args.batch_size)]
                fake_imagesB = fake_poolB[np.random.randint(
                    args.memory_size, size=args.batch_size)]
                if args.variable_size:
                    fake_imagesA = [
                        random_augmentation(image, crop_size, resize_size)
                        for image in fake_imagesA
                    ]
                    fake_imagesB = [
                        random_augmentation(image, crop_size, resize_size)
                        for image in fake_imagesB
                    ]
                fakeA = chainer.Variable(genA.xp.asarray(fake_imagesA))
                fakeB = chainer.Variable(genA.xp.asarray(fake_imagesB))

            ############################
            # (1) Update D network
            ###########################
            # dis A
            y_realA = disA(realA)
            y_fakeA = disA(fakeA)
            loss_disA = (F.sum((y_realA - args.real_label) ** 2) + F.sum((y_fakeA - args.fake_label) ** 2)) \
                        / np.prod(y_fakeA.shape)

            # dis B
            y_realB = disB(realB)
            y_fakeB = disB(fakeB)
            loss_disB = (F.sum((y_realB - args.real_label) ** 2) + F.sum((y_fakeB - args.fake_label) ** 2)) \
                        / np.prod(y_fakeB.shape)

            # optionally also penalize disA on real B (and disB on real A), not only real-vs-fake pairs
            if args.lambda_dis_real > 0:
                y_realB = disA(realB)
                loss_disA += F.sum(
                    (y_realB - args.fake_label)**2) / np.prod(y_realB.shape)
                y_realA = disB(realA)
                loss_disB += F.sum(
                    (y_realA - args.fake_label)**2) / np.prod(y_realA.shape)

            # update dis
            disA.cleargrads()
            disB.cleargrads()
            loss_disA.backward()
            loss_disB.backward()
            optimizer_disA.update()
            optimizer_disB.update()

            ###########################
            # (2) Update G network
            ###########################

            # gan A
            fakeA = genA(realB)
            y_fakeA = disA(fakeA)
            loss_ganA = F.sum(
                (y_fakeA - args.real_label)**2) / np.prod(y_fakeA.shape)

            # gan B
            fakeB = genB(realA)
            y_fakeB = disB(fakeB)
            loss_ganB = F.sum(
                (y_fakeB - args.real_label)**2) / np.prod(y_fakeB.shape)

            # rec A
            recA = genA(fakeB)
            loss_recA = F.mean_absolute_error(recA, realA)

            # rec B
            recB = genB(fakeA)
            loss_recB = F.mean_absolute_error(recB, realB)

            # gen loss
            loss_gen = loss_ganA + loss_ganB + lambda_ * (loss_recA +
                                                          loss_recB)
            # loss_genB = loss_ganB + lambda_ * (loss_recB + loss_recA)

            # update gen
            genA.cleargrads()
            genB.cleargrads()
            loss_gen.backward()
            # loss_genB.backward()
            optimizer_genA.update()
            optimizer_genB.update()

            # logging
            logger.plot('loss dis A', float(loss_disA.data))
            logger.plot('loss dis B', float(loss_disB.data))
            logger.plot('loss rec A', float(loss_recA.data))
            logger.plot('loss rec B', float(loss_recB.data))
            logger.plot('loss gen A', float(loss_gen.data))
            # logger.plot('loss gen B', float(loss_genB.data))
            logger.tick()

            # save to replay buffer
            fakeA = cuda.to_cpu(fakeA.data)
            fakeB = cuda.to_cpu(fakeB.data)
            for k in range(args.batch_size):
                fake_sampleA = fakeA[k]
                fake_sampleB = fakeB[k]
                if args.variable_size:
                    fake_sampleA = cv2.resize(
                        fake_sampleA.transpose(1, 2, 0), (256, 256),
                        interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
                    fake_sampleB = cv2.resize(
                        fake_sampleB.transpose(1, 2, 0), (256, 256),
                        interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
                fake_poolA[(iterations * args.batch_size) % args.memory_size +
                           k] = fake_sampleA
                fake_poolB[(iterations * args.batch_size) % args.memory_size +
                           k] = fake_sampleB

            iterations += 1
            progress_report(iterations, start, args.batch_size)

        if epoch % 5 == 0:
            logger.flush(out_dir)
            visualize(genA,
                      genB,
                      const_realA,
                      const_realB,
                      epoch=epoch,
                      savedir=os.path.join(out_dir, 'visualize'))

            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.disA.model".format(epoch)), disA)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.disB.model".format(epoch)), disB)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.genA.model".format(epoch)), genA)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.genB.model".format(epoch)), genB)
Example #14
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          extension, train_size, valid_size, data_path, sketch_path, digi_path,
          learning_rate, beta1, weight_decay):

    # Dataset definition
    dataset = DatasetLoader(data_path, sketch_path, digi_path, extension,
                            train_size, valid_size)
    print(dataset)
    x_val, t_val = dataset.valid(validsize)

    # Model & Optimizer definition
    unet = UNet()
    unet.to_gpu()
    unet_opt = set_optimizer(unet, learning_rate, beta1, weight_decay)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss function definition
    lossfunc = Pix2pixLossCalculator()

    # Visualization definition
    visualizer = Visualizer()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, t = dataset.train(batchsize)

            # Discriminator update
            y = unet(x)
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, y, t)

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            y = unet(x)

            gen_loss = lossfunc.gen_loss(discriminator, y)
            gen_loss += lossfunc.content_loss(y, t)

            unet.cleargrads()
            gen_loss.backward()
            unet_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/unet_{epoch}.model", unet)

                with chainer.using_config("train", False):
                    y = unet(x_val)

                x = x_val.data.get()
                t = t_val.data.get()
                y = y.data.get()

                visualizer(x, t, y, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}"
        )
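Pix2pixLossCalculator is not reproduced here. A compact sketch matching the three calls above (dis_loss, gen_loss, content_loss) follows; the softplus adversarial form and the L1 weight are assumptions, and the original discriminator may additionally condition on the input sketch.

import chainer.functions as F


class Pix2pixLossCalculator:
    def dis_loss(self, discriminator, fake, real):
        # non-saturating adversarial loss for the discriminator
        return F.mean(F.softplus(discriminator(fake))) + F.mean(F.softplus(-discriminator(real)))

    def gen_loss(self, discriminator, fake):
        return F.mean(F.softplus(-discriminator(fake)))

    def content_loss(self, fake, real, weight=10.0):
        # L1 reconstruction term between the U-Net output and the target
        return weight * F.mean_absolute_error(fake, real)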
def train(epochs, iterations, batchsize, testsize, outdir, modeldir, n_dis,
          img_path, tag_path):
    # Dataset Definition
    dataloader = DatasetLoader(img_path, tag_path)
    zvis_valid, ztag_valid = dataloader.valid(batchsize)
    noise_valid = F.concat([zvis_valid, ztag_valid])

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = RGANLossFunction()

    # Evaluation
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                zvis, ztag, img, tag = dataloader.train(batchsize)

                y = generator(F.concat([zvis, ztag]))
                y.unchain_backward()

                loss = lossfunc.dis_loss(discriminator, y, img, tag, ztag)
                loss += lossfunc.gradient_penalty(discriminator, img, tag)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            zvis, ztag, _, _ = dataloader.train(batchsize)

            y = generator(F.concat([zvis, ztag]))

            loss = lossfunc.gen_loss(discriminator, y, ztag)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)

                with chainer.using_config('train', False):
                    y = generator(noise_valid)
                y = y.data.get()

                evaluator(y, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example #16
def main(args):

    #initialize models and load mnist dataset
    G = Generator()
    D = Discriminator()
    x = load_dataset()

    #build optimizer of generator
    opt_generator = chainer.optimizers.Adam().setup(G)
    opt_generator.use_cleargrads()

    #build optimizer of discriminator
    opt_discriminator = chainer.optimizers.Adam().setup(D)
    opt_discriminator.use_cleargrads()

    #make the output folder
    if not os.path.exists(args.output):
        os.makedirs(args.output, exist_ok=True)

    #list of loss
    Glosses = []
    Dlosses = []

    print("Now starting training loop...")

    #begin training process
    for train_iter in range(1, args.num_epochs + 1):

        for i in range(0, len(x), 100):

            #Clears all gradient arrays.
            #The following should be called before the backward computation at every iteration of the optimization.
            G.cleargrads()
            D.cleargrads()

            #Train the generator
            noise_samples = sample(100)
            Gloss = 0.5 * F.sum(F.square(D(G(np.asarray(noise_samples))) - 1))
            Gloss.backward()
            opt_generator.update()

            #As above
            G.cleargrads()
            D.cleargrads()

            #Train the discriminator
            noise_samples = sample(100)
            Dreal = D(np.asarray(x[i:i + 100]))
            Dgen = D(G(np.asarray(noise_samples)))
            Dloss = 0.5 * F.sum(F.square(
                (Dreal - 1.0))) + 0.5 * F.sum(F.square(Dgen))
            Dloss.backward()
            opt_discriminator.update()

        #save loss from each batch
        Glosses.append(Gloss.data)
        Dlosses.append(Dloss.data)

        if train_iter % 10 == 0:

            print("epoch {0:04d}".format(train_iter), end=", ")
            print("Gloss: {}".format(Gloss.data), end=", ")
            print("Dloss: {}".format(Dloss.data))

            noise_samples = sample(100)
            print_sample(
                os.path.join(args.output,
                             "epoch_{0:04}.png".format(train_iter)),
                noise_samples, G)

    print("The training process is finished.")

    plotLoss(train_iter, Dlosses, Glosses)
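This MNIST example depends on load_dataset, sample, print_sample, and plotLoss, none of which are shown. As an illustration only, a noise sampler compatible with how sample(100) is used above might be:

import numpy as np


def sample(batchsize, latent_dim=100):
    # uniform latent noise; the latent dimension is a guess, not taken from the original code
    return np.random.uniform(-1.0, 1.0, (batchsize, latent_dim)).astype(np.float32)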
Example #17
        fake.unchain_backward()
        fake_2.unchain_backward()
        fake_4.unchain_backward()

        # LSGAN
        #adver_loss=0.5*(F.sum((dis_color-1.0)**2)+F.sum(dis_fake**2))/batchsize
        #adver_loss+=0.5*(F.sum((dis2_color-1.0)**2)+F.sum(dis2_fake**2))/batchsize
        #adver_loss+=0.5*(F.sum((dis4_color-1.0)**2)+F.sum(dis4_fake**2))/batchsize

        # DCGAN
        adver_loss = F.mean(F.softplus(-dis_color)) + F.mean(F.softplus(dis_fake))
        adver_loss += F.mean(F.softplus(-dis2_color)) + F.mean(F.softplus(dis2_fake))
        adver_loss += F.mean(F.softplus(-dis4_color)) + F.mean(F.softplus(dis4_fake))

        discriminator.cleargrads()
        discriminator_2.cleargrads()
        discriminator_4.cleargrads()
        adver_loss.backward()
        dis_opt.update()
        dis2_opt.update()
        dis4_opt.update()
        adver_loss.unchain_backward()

        fake, _ = global_generator(line)
        fake_2 = F.average_pooling_2d(fake, 3, 2, 1)
        fake_4 = F.average_pooling_2d(fake_2, 3, 2, 1)

        dis_fake, fake_feat = discriminator(F.concat([line, fake]))
        dis2_fake, fake_feat2 = discriminator_2(F.concat([line_2, fake_2]))
        dis4_fake, fake_feat3 = discriminator_4(F.concat([line_4, fake_4]))
Example #18
def train_refine(epochs,
                 iterations,
                 batchsize,
                 validsize,
                 data_path,
                 sketch_path,
                 digi_path,
                 st_path,
                 extension,
                 img_size,
                 crop_size,
                 outdir,
                 modeldir,
                 adv_weight,
                 enf_weight):

    # Dataset Definition
    dataloader = RefineDataset(data_path, sketch_path, digi_path, st_path,
                               extension=extension, img_size=img_size, crop_size=crop_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid, cm_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", base=64, bn=True, activ=F.relu)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    iteration = 0

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            iteration += 1
            color, line, mask, mask_ds, color_mask = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(color_mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            # Discriminator update
            fake, _ = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)

            fake.unchain_backward()

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)

            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                extractor = vgg(cm_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()
                line_valid_input = F.concat([line_valid, mask_valid])

                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)

                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                cm_val = cm_valid.data.get()
                guide_valid = guide_valid.data.get()
                input_valid = np.concatenate([input_valid[:, 3:6], cm_val], axis=1)

                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

            print(f"iter: {iteration} dis loss: {sum_dis_loss} gen loss: {gen_loss}")
Example #19
class Train:
    def __init__(self):
        self.data = dataset()
        self.data.reset()
        self.reset()
        # self.load(1)
        self.setLR()
        self.time = time.time()
        self.dataRate = xp.float32(0.8)
        self.mado = xp.hanning(442).astype(xp.float32)
        # n=10
        # load_npz(f"param/gen/gen_{n}.npz",self.generator)
        # load_npz(f"param/dis/dis_{n}.npz",self.discriminator)
        self.training(batchsize=6)

    def reset(self):
        self.generator = None
        self.discriminator = None
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.generator.to_gpu()
        self.discriminator.to_gpu()

    def setLR(self, lr=0.002):
        self.gen_opt = optimizers.Adam(alpha=lr)
        self.gen_opt.setup(self.generator)
        self.gen_opt.add_hook(optimizer.WeightDecay(0.0001))
        self.dis_opt = optimizers.Adam(alpha=lr)
        self.dis_opt.setup(self.discriminator)
        self.dis_opt.add_hook(optimizer.WeightDecay(0.0001))

    # def save(self, i):
    #     with open(f"param/com/com{i}.pickle", mode='wb') as f:
    #         pickle.dump(self.compressor, f)
    #     with open(f"param/gen/gen{i}.pickle", mode='wb') as f:
    #         pickle.dump(self.generator, f)
    #     with open(f"param/dis/dis{i}.pickle", mode='wb') as f:
    #         pickle.dump(self.discriminator, f)

    # def load(self, i):
    #     with open(f"param/com/com{i}.pickle", mode='rb') as f:
    #         self.compressor = pickle.load(f)
    #     with open(f"param/gen/gen{i}.pickle", mode='rb') as f:
    #         self.generator = pickle.load(f)
    #     with open(f"param/dis/dis{i}.pickle", mode='rb') as f:
    #         self.discriminator = pickle.load(f)

    def encode(self, x):
        # print(x.shape)
        # print(x.shape)
        a, b, c = x.shape
        x = x.reshape(a, 1, c).astype(xp.float32)
        # x = xp.hstack([x[:,:,i:b-440+i:221] for i in range(441)]) * hamming
        x = xp.concatenate([
            x[:, :, :-221].reshape(a, -1, 1, 442), x[:, :, 221:].reshape(
                a, -1, 1, 442)
        ],
                           axis=2).reshape(a, -1, 442) * self.mado
        # print(x)

        x = xp.fft.fft(x, axis=-1)
        # xp.fft.fft(xp.arange(100).reshape(2,5,10),axis=-1)
        x = xp.concatenate(
            [x.real.reshape(a, 1, -1, 442),
             x.imag.reshape(a, 1, -1, 442)],
            axis=1)
        #.reshape(a, 2, -1, 442)
        # xp.concatenate([s.real.reshape(2,5,1,10),s.imag.reshape(2,5,1,10)],axis=2)
        # print(x.shape)
        x = xp.transpose(x, axes=(0, 1, 3, 2))
        # print(x.dtype)
        return x

    def decode(self, x):
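        # Inverse of encode(): recombine the real/imaginary channels, apply the
        # inverse FFT, divide out the window, and overlap-add the 442-sample
        # frames back into a waveform.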
        # print(x.shape)
        a, b, c, d = x.shape
        x = x[:, 0] + x[:, 1] * 1j
        # print(x.shape)
        # x = xp.transpose(x.reshape(a, -1, 442), axes=(0,1,3,2))
        # print(x.shape)
        # x = x.reshape(x.shape[0], -1, 442)
        x = xp.transpose(xp.fft.ifft(x, axis=1).real, axes=(0, 2, 1))
        # print(x.shape)
        x /= self.mado
        x = x[:, :-1:2].reshape(a, -1)[:, 221:] + x[:, 1::2].reshape(
            a, -1)[:, :-221]
        # print(x.shape)
        return x

    def training(self, batchsize=1):
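        # Outer training loop: every 10 steps log the losses and accuracy,
        # every 100 steps checkpoint both networks and write out a test
        # conversion of the first test waveform.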
        for x in range(100):
            N = self.data.reset()
            # a,b,c=self.data.test()
            # d=F.argmax(self.generator(a.astype(xp.float32),b.astype(xp.int16),c.astype(xp.int16)),-2).data.get().reshape(-1)
            # print(d[25000:26000])
            # self.data.save(d, "_")
            # self.batch(batchsize = 1)
            for i in range(N // batchsize - 1):
                # if not i%1:
                # self.save(i)
                # g=copy.deepcopy(self.generator).to_cpu
                # g.to_cpu
                # print(d[25000:25100])
                res = self.batch(batchsize=batchsize)
                if not i % 10:
                    print(
                        f"{i} time:{int(time.time()-self.time)} G_Loss:{res[0][0]} {res[0][1]} D_Loss:{res[1][0]+res[1][1]} D_Acc:{res[2]}"
                    )
                    if not i % 100:
                        # save_npz(f"param/com/com_{i}.npz",self.compressor)
                        save_npz(f"param/gen/gen_{i}.npz", self.generator)
                        save_npz(f"param/dis/dis_{i}.npz", self.discriminator)
                        a = xp.asarray(self.data.testData[0][:88200].reshape(
                            1, 1, 1, -1))

                        # a=self.encode(a.reshape(1,1,-1)[:,:,:a.shape[-1]//442*442-221])
                        # a=self.encode(a.reshape(1,1,-1)[:,:,:112047])
                        # b=self.encode(b)
                        # c=self.encode(c)
                        d = self.generator(a, xp.array([110])).data.get()
                        # d=self.decode(d).get()
                        # print(d.shape)
                        self.data.save(d.flatten(), f"Garagara_{i}")

                        # del d

                    # print(res[-1][0])
                    # print(res[-1][1])

    def batch(self, batchsize=2):
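        # One optimization step: draw a batch, pick a random target label c_
        # that differs from the source label c, then update the discriminator
        # (real/fake and class losses) and the generator (adversarial, class
        # and reconstruction losses).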
        x, c = self.data.next(batchSize=batchsize,
                              dataSize=[8190],
                              dataSelect=[0])
        x = x[0].reshape(batchsize, 1, 1, -1)
        c = xp.asarray(c[0])
        c_ = xp.random.randint(0, 111, batchsize)
        c_ = c_ + (c_ >= c)  # shift so the random target label never equals the source label c
        # t = next(self.test)
        # t = self.data.test(size=6143)
        # _ = lambda x:self.encode(x)
        # _ = lambda x:x/xp.float32(32768)
        # B0_ = _(B0)
        A_gen = self.generator(x, c_)
        # print(A_gen.shape)
        B_gen = self.generator(x, c)

        F_tf, F_c = self.discriminator(A_gen[:, :, :, 5119:])
        T_tf, T_c = self.discriminator(x[:, :, :, 2047:-5119])

        dis_acc = (F.argmax(F_tf, axis=1).data.sum(),
                   xp.int32(batchsize) - F.argmax(T_tf, axis=1).data.sum(),
                   (T_c.data.argmax(axis=-1) == c).sum())
        # acc = (dis_acc[0]+dis_acc[1])/8

        # self.dataRate = self.dataRate if dis_acc[0] == dis_acc[1] else self.dataRate / xp.float32(0.99) if dis_acc[0] > dis_acc[1] else self.dataRate * xp.float32(0.99)

        # receptionSize = B0.shape[-1] - B_gen.shape[-1]
        # L_gen0 = F.softmax_cross_entropy(B_gen, B0[:,:,receptionSize:].reshape(batchsize,-1))
        # print(B_gen.shape)
        # print(B0_.shape)
        # L_gen0 = 0
        L_gen0 = F.mean_squared_error(B_gen, x[:, :, :, 1023:-1024])
        L_gen1 = F.softmax_cross_entropy(F_tf,
                                         xp.zeros(batchsize, dtype=np.int32))
        L_gen2 = F.softmax_cross_entropy(F_c, c_)
        gen_loss = (L_gen0.data, L_gen1.data)
        L_gen = L_gen1 + L_gen0 + L_gen2
        # L_gen = L_gen1 + (L_gen0 if L_gen0.data > 0.0001 else 0)

        L_dis0 = F.softmax_cross_entropy(F_tf,
                                         xp.ones(batchsize, dtype=np.int32))
        L_dis1 = F.softmax_cross_entropy(T_tf,
                                         xp.zeros(batchsize, dtype=np.int32))
        L_dis2 = F.softmax_cross_entropy(T_c, c)
        dis_loss = (L_dis0.data.get(), L_dis1.data.get(), L_dis2.data.get())
        # L_dis = L_dis0 * min(xp.float32(1), 1 / self.dataRate) + L_dis1 * min(xp.float32(1), self.dataRate)
        L_dis = L_dis0 + L_dis1 + L_dis2

        self.generator.cleargrads()
        L_gen.backward()
        self.gen_opt.update()

        self.discriminator.cleargrads()
        L_dis.backward()
        self.dis_opt.update()

        self.dis_opt.alpha *= 0.99999
        self.gen_opt.alpha *= 0.99999
        return (gen_loss, dis_loss, dis_acc, self.dataRate,
                (F_tf.data, T_tf.data))

    def garagara(self):
        pass
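
A minimal entry point for this example is sketched below; it assumes the names used above (dataset, Generator, Discriminator, xp, save_npz and the Chainer optimizers) are already imported at module level:

if __name__ == "__main__":
    # Instantiating Train() builds the models and immediately starts
    # training with batchsize=6 (see __init__ above).
    Train()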
Example #20
0
def train(epochs,
          iterations,
          batchsize,
          validsize,
          data_path,
          sketch_path,
          digi_path,
          extension,
          img_size,
          outdir,
          modeldir,
          pretrained_epoch,
          adv_weight,
          enf_weight,
          sn,
          bn,
          activ):

    # Dataset Definition
    dataloader = DataLoader(data_path, sketch_path, digi_path,
                            extension=extension, img_size=img_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", bn=bn, activ=activ)
    #generator = SAGenerator(attn_type="sa", base=64)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator(sn=sn)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line, mask, mask_ds = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

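            # After the pretraining phase, the passed-in weights are overridden:
            # the adversarial term is fixed at 0.1 and the positive-enforcing
            # term is disabled.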
            if epoch > pretrained_epoch:
                adv_weight = 0.1
                enf_weight = 0.0

            # Discriminator update
            fake, _ = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)

            fake.unchain_backward()

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)

            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += enf_weight * lossfunc.positive_enforcing_loss(fake)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)
            loss += lossfunc.perceptual_loss(vgg, fake, color)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                extractor = vgg(line_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()
                line_valid_input = F.concat([line_valid, mask_valid])

                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)

                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                guide_valid = guide_valid.data.get()

                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")