Example #1
    def zero_centered_gradient_penalty_fake(fake, y):
        # Gradient of the discriminator output `fake` w.r.t. the generated
        # input `y`, with double backprop enabled so the penalty itself can
        # be differentiated during the update.
        grad, = chainer.grad([fake], [y], enable_double_backprop=True)
        # Per-sample L2 norm of the gradient.
        grad = F.sqrt(F.batch_l2_norm_squared(grad))
        zeros = call_zeros(grad)

        # Push the gradient norms toward zero (weight 10).
        loss = 10 * F.mean_squared_error(grad, zeros)

        return loss
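
None of these snippets shows its imports or the `call_zeros` / `call_ones` helpers; all of them assume the usual Chainer setup (`import chainer`, `import chainer.functions as F`, `from chainer import serializers`). Below is a minimal sketch of what the two helpers are assumed to do, namely build constant target arrays matching a reference variable; the actual definitions in the repository may differ:

    import chainer

    def call_zeros(var):
        # Zero-filled target with the same shape, dtype and device as `var`,
        # wrapped as a Variable so it can serve as a regression target.
        xp = chainer.backend.get_array_module(var)
        return chainer.Variable(xp.zeros_like(var.array))

    def call_ones(var):
        # One-filled counterpart, used for the "matched real" target in Example #4.
        xp = chainer.backend.get_array_module(var)
        return chainer.Variable(xp.ones_like(var.array))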
Example #2
    def zero_centered_gradient_penalty_real(discriminator, t):
        # Re-wrap the real sample as a fresh Variable so the gradient is
        # taken w.r.t. the input itself, detached from any upstream graph.
        t = chainer.Variable(t.data)
        real = discriminator(t, y=None, label=None, method="adv")

        grad, = chainer.grad([real], [t], enable_double_backprop=True)
        # Per-sample L2 norm of the gradient, pushed toward zero (weight 10).
        grad = F.sqrt(F.batch_l2_norm_squared(grad))
        zeros = call_zeros(grad)

        loss = 10 * F.mean_squared_error(grad, zeros)

        return loss
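
Taken together, Examples #1 and #2 compute the same quantity on fake and real samples respectively: since the per-sample gradient norm is regressed against zero, each loss reduces to 10 * E[||grad D||^2], i.e. a zero-centered gradient penalty in the style of the R1/R2 regularizers of Mescheder et al. (2018).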
Example #3
    def interpolation_loss_dis(discriminator, y0, y1, yalpha, alpha, flag):
        fake_y0 = discriminator(y0, y=None, label=None, method="inp")
        fake_y1 = discriminator(y1, y=None, label=None, method="inp")
        fake_yalpha = discriminator(yalpha, y=None, label=None, method="inp")

        zeros = call_zeros(fake_y0)

        # The endpoint selected by `flag` is trained toward 0, while the
        # interpolated sample is trained toward its interpolation coefficient.
        if flag == 0:
            loss = F.mean_squared_error(fake_y0, zeros)
            loss += F.mean_squared_error(fake_yalpha, alpha)
        else:
            loss = F.mean_squared_error(fake_y1, zeros)
            loss += F.mean_squared_error(fake_yalpha, (1 - alpha))

        return loss
Example #4
    def matching_loss_dis(discriminator, x, fake, y, z, v1, v2, v3):
        # Correctly matched triple.
        sr = discriminator(x, y, v1, method="mat")
        # Generated sample in place of the real target.
        sf = discriminator(x, fake, v1, method="mat")
        # Mismatched sources, targets, or conditions.
        sw0 = discriminator(z, y, v1, method="mat")
        sw1 = discriminator(x, y, v2, method="mat")
        sw2 = discriminator(x, y, v3, method="mat")
        sw3 = discriminator(x, z, v1, method="mat")

        zeros = call_zeros(sr)
        ones = call_ones(sr)

        loss = F.mean_squared_error(sr, ones)
        loss += F.mean_squared_error(sf, zeros)
        loss += F.mean_squared_error(sw0, zeros)
        loss += F.mean_squared_error(sw1, zeros)
        loss += F.mean_squared_error(sw2, zeros)
        loss += F.mean_squared_error(sw3, zeros)

        return loss
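
The design here is that only the correctly matched triple (x, y, v1) is regressed toward 1; the generated sample and every mismatched combination are regressed toward 0, so the discriminator learns to verify that inputs and conditions actually belong together rather than merely judging realism.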
Example #5
    def interpolation_loss_gen(discriminator, yalpha):
        fake_yalpha = discriminator(yalpha, y=None, label=None, method="inp")

        zeros = call_zeros(fake_yalpha)

        # The generator drives the interpolate's score to 0, i.e. it tries to
        # make interpolated samples indistinguishable from the endpoints.
        return F.mean_squared_error(fake_yalpha, zeros)
Example #6
def train(epochs, iterations, batchsize, modeldir, extension, time_width,
          mel_bins, sampling_rate, g_learning_rate, d_learning_rate, beta1,
          beta2, identity_epoch, adv_type, residual_flag, data_path):

    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    generator = GeneratorWithCIN(adv_type=adv_type)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, g_learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, d_learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            # 'sat' (source-and-target) conditions on both labels;
            # 'orig' conditions on the target label only.
            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
            else:
                raise AttributeError("adv_type must be 'sat' or 'orig'")

            # Detach so the discriminator update does not reach the generator.
            y_fake.unchain_backward()

            # Discriminator adversarial losses on real and (detached) fake samples.
            if adv_type == 'sat':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, F.concat([y_label, x_label]),
                    F.concat([x_label, y_label]), residual_flag)
            elif adv_type == 'orig':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, y_label, x_label,
                    residual_flag)
            else:
                raise AttributeError("adv_type must be 'sat' or 'orig'")

            dis_loss = advloss_dis_real + advloss_dis_fake
            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            # Rebuild the graph for the generator update: fake, cycle-
            # reconstructed, and identity-mapped spectrograms.
            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
                x_fake = generator(y_fake, F.concat([x_label, y_label]))
                x_identity = generator(x_sp, F.concat([x_label, x_label]))
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp,
                    F.concat([y_label, x_label]), residual_flag)
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
                x_fake = generator(y_fake, x_label)
                x_identity = generator(x_sp, x_label)
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp, y_label,
                    residual_flag)
            else:
                raise AttributeError("adv_type must be 'sat' or 'orig'")

            # The identity-mapping loss is only applied during the first
            # `identity_epoch` epochs; afterwards it is replaced with zeros.
            if epoch < identity_epoch:
                identity_loss = lossfunc.identity_loss(x_identity, x_sp)
            else:
                identity_loss = call_zeros(advloss_dis_fake)

            gen_loss = advloss_gen_fake + cycle_loss + identity_loss
            generator.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            gen_loss.unchain_backward()

            sum_dis_loss += dis_loss.data
            sum_gen_loss += gen_loss.data

            # Snapshot the generator at the start of every epoch.
            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
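
For reference, a minimal sketch of how `train` might be invoked; every value below is a hypothetical placeholder chosen for illustration, not taken from the original repository:

    if __name__ == "__main__":
        # All hyperparameter values here are illustrative placeholders.
        train(epochs=50, iterations=1000, batchsize=16,
              modeldir="./models", extension=".npy", time_width=128,
              mel_bins=36, sampling_rate=22050,
              g_learning_rate=2e-4, d_learning_rate=1e-4,
              beta1=0.5, beta2=0.999,
              identity_epoch=20, adv_type="sat",
              residual_flag=False, data_path="./data")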